Feature/sg 747 add preprocessing #804

Merged · 41 commits · Apr 3, 2023
Commits (41)
43176f2
wip
Louis-Dupont Mar 26, 2023
5a0023b
move to imageprocessors
Louis-Dupont Mar 26, 2023
1aacdfa
Merge branch 'master' into feature/SG-747-add_image_processor
Louis-Dupont Mar 26, 2023
89c48a5
wip
Louis-Dupont Mar 27, 2023
6958813
add back changes
Louis-Dupont Mar 27, 2023
4ae57b1
making it work fully for yolox and almost for ppyoloe
Louis-Dupont Mar 27, 2023
2700b80
minor change
Louis-Dupont Mar 27, 2023
b48c596
working for det
Louis-Dupont Mar 28, 2023
e5366c5
Merge branch 'master' into feature/SG-747-add_preprocessing
Louis-Dupont Mar 28, 2023
0ac4fe8
cleaning
Louis-Dupont Mar 28, 2023
24c16c8
clean
Louis-Dupont Mar 28, 2023
2735cf8
undo
Louis-Dupont Mar 28, 2023
3587cee
replace empty with none
Louis-Dupont Mar 28, 2023
4a50611
Merge branch 'master' into feature/SG-747-add_preprocessing
Louis-Dupont Mar 28, 2023
6a4250e
add _get_shift_params
Louis-Dupont Mar 28, 2023
061aa5d
minor doc change
Louis-Dupont Mar 28, 2023
0031494
Merge branch 'master' into feature/SG-747-add_preprocessing
Louis-Dupont Mar 29, 2023
2464398
replace pydantic with dataclasses and fix typing
Louis-Dupont Mar 29, 2023
d4c0774
add docstrings
Louis-Dupont Mar 29, 2023
cf19765
doc improvment and use get_shift_params in transforms
Louis-Dupont Mar 29, 2023
7e8ad22
add tests
Louis-Dupont Mar 29, 2023
90f708e
improve comment
Louis-Dupont Mar 29, 2023
8830ba9
rename
Louis-Dupont Mar 29, 2023
efd58d4
Merge branch 'master' into feature/SG-747-add_preprocessing
Louis-Dupont Mar 29, 2023
74379c6
add option to keep ratio in rescale
Louis-Dupont Mar 29, 2023
efbde36
Merge branch 'master' into feature/SG-747-add_preprocessing
Louis-Dupont Mar 29, 2023
efd023e
make functions private
Louis-Dupont Mar 29, 2023
008b77b
remove DetectionPaddedRescale
Louis-Dupont Mar 29, 2023
77addfa
fix doc
Louis-Dupont Mar 29, 2023
d6c0f9b
add fixes
Louis-Dupont Mar 30, 2023
0cb58e2
improve _get_center_padding_params output
Louis-Dupont Mar 30, 2023
f0baed7
minor fix
Louis-Dupont Mar 30, 2023
1a32cf2
add empty bbox test for rescale_bboxes
Louis-Dupont Mar 30, 2023
dcfd902
finalizing _DetectionPadding, DetectionCenterPadding and DetectionBot…
Louis-Dupont Mar 30, 2023
858ecc0
remove _pad_to_side
Louis-Dupont Mar 30, 2023
a19f591
split rescale into 2 classes
Louis-Dupont Mar 30, 2023
3229c54
minor addition
Louis-Dupont Mar 30, 2023
b012d46
Add DetectionPrediction object
Louis-Dupont Apr 2, 2023
3571780
simplify DetectionPrediction class
Louis-Dupont Apr 3, 2023
7b73edb
add round and don't rescale if no change required
Louis-Dupont Apr 3, 2023
68e5097
Merge branch 'master' into feature/SG-747-add_preprocessing
BloodAxe Apr 3, 2023
@@ -9,4 +9,4 @@ anchors:
yolo_type: 'yoloX'

depth_mult_factor: 0.33
width_mult_factor: 0.5
width_mult_factor: 0.5
@@ -24,4 +24,4 @@ val_dataset_params:
mean: [0.4802, 0.4481, 0.3975]
std: [0.2770, 0.2691, 0.2821]

_convert_: all
_convert_: all
193 changes: 193 additions & 0 deletions src/super_gradients/training/transforms/processing.py
@@ -0,0 +1,193 @@
from typing import Tuple, List, Union
from abc import ABC, abstractmethod
from dataclasses import dataclass

import numpy as np

from super_gradients.training.transforms.utils import (
_rescale_image,
_rescale_bboxes,
_pad_image_on_side,
_get_center_padding_params,
_pad_image,
_shift_bboxes,
)


@dataclass
class ProcessingMetadata(ABC):
"""Metadata including information to postprocess a prediction."""


@dataclass
class ComposeProcessingMetadata(ProcessingMetadata):
metadata_lst: List[Union[None, ProcessingMetadata]]


@dataclass
class DetectionPadToSizeMetadata(ProcessingMetadata):
pad_top: float
pad_left: float


@dataclass
class RescaleMetadata(ProcessingMetadata):
original_shape: Tuple[int, int]
scale_factor_h: float
scale_factor_w: float


class Processing(ABC):
"""Interface for preprocessing and postprocessing methods that are
used to prepare images for a model and process the model's output.

Subclasses should implement the `preprocess_image` and `postprocess_predictions`
methods according to the specific requirements of the model and task.
"""

@abstractmethod
def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, Union[None, ProcessingMetadata]]:
"""Processing an image, before feeding it to the network. Expected to be in (H, W, C) or (H, W)."""
pass

@abstractmethod
def postprocess_predictions(self, predictions: np.ndarray, metadata: Union[None, ProcessingMetadata]) -> np.ndarray:
"""Postprocess the model output predictions."""
pass


class ComposeProcessing(Processing):
"""Compose a list of Processing objects into a single Processing object."""

def __init__(self, processings: List[Processing]):
self.processings = processings

def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, ComposeProcessingMetadata]:
"""Processing an image, before feeding it to the network."""
processed_image, metadata_lst = image.copy(), []
for processing in self.processings:
processed_image, metadata = processing.preprocess_image(image=processed_image)
metadata_lst.append(metadata)
return processed_image, ComposeProcessingMetadata(metadata_lst=metadata_lst)

def postprocess_predictions(self, predictions: np.ndarray, metadata: ComposeProcessingMetadata) -> np.ndarray:
"""Postprocess the model output predictions."""
postprocessed_predictions = predictions
for processing, metadata in zip(self.processings[::-1], metadata.metadata_lst[::-1]):
postprocessed_predictions = processing.postprocess_predictions(postprocessed_predictions, metadata)
return postprocessed_predictions


class ImagePermute(Processing):
"""Permute the image dimensions.

:param permutation: Specify new order of dims. Default value (2, 0, 1) suitable for converting from HWC to CHW format.
"""

def __init__(self, permutation: Tuple[int, int, int] = (2, 0, 1)):
self.permutation = permutation

def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, None]:
processed_image = np.ascontiguousarray(image.transpose(*self.permutation))
return processed_image, None

def postprocess_predictions(self, predictions: np.ndarray, metadata: None) -> np.ndarray:
return predictions


class NormalizeImage(Processing):
"""Normalize an image based on means and standard deviation.

:param mean: Mean values for each channel.
:param std: Standard deviation values for each channel.
"""

def __init__(self, mean: List[float], std: List[float]):
self.mean = np.array(mean).reshape((1, 1, -1)).astype(np.float32)
self.std = np.array(std).reshape((1, 1, -1)).astype(np.float32)

def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, None]:
return (image - self.mean) / self.std, None

def postprocess_predictions(self, predictions: np.ndarray, metadata: None) -> np.ndarray:
return predictions


class DetectionCenterPadding(Processing):
"""Preprocessing transform to pad image and bboxes to `output_shape` shape (H, W).
Center padding, so that input image with bboxes located in the center of the produced image.

Note: This transformation assume that dimensions of input image is equal or less than `output_shape`.

:param output_shape: Output image shape (H, W)
:param pad_value: Padding value for image
"""

def __init__(self, output_shape: Tuple[int, int], pad_value: int):
self.output_shape = output_shape
self.pad_value = pad_value

def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, DetectionPadToSizeMetadata]:
pad_top, pad_bot, pad_left, pad_right = _get_center_padding_params(input_shape=image.shape, output_shape=self.output_shape)
processed_image = _pad_image(image, (pad_top, pad_bot), (pad_left, pad_right), self.pad_value)

return processed_image, DetectionPadToSizeMetadata(pad_top=pad_top, pad_left=pad_left)

def postprocess_predictions(self, predictions: np.ndarray, metadata: DetectionPadToSizeMetadata) -> np.ndarray:
return _shift_bboxes(targets=predictions, shift_h=-metadata.pad_top, shift_w=-metadata.pad_left)


class DetectionSidePadding(Processing):
"""Preprocessing transform to pad image and bboxes to `output_shape` shape (H, W).
Side padding, so that input image with bboxes will located on the side. Bboxes won't be affected.

Note: This transformation assume that dimensions of input image is equal or less than `output_shape`.

:param output_shape: Output image shape (H, W)
:param pad_value: Padding value for image
"""

def __init__(self, output_shape: Tuple[int, int], pad_value: int):
self.output_shape = output_shape
self.pad_value = pad_value

def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, None]:
processed_image = _pad_image_on_side(image, output_shape=self.output_shape, pad_val=self.pad_value)
return processed_image, None

def postprocess_predictions(self, predictions: np.ndarray, metadata: None) -> np.ndarray:
return predictions


class _Rescale(Processing, ABC):
"""Resize image to given image dimensions WITHOUT preserving aspect ratio.

:param output_shape: (H, W)
"""

def __init__(self, output_shape: Tuple[int, int], keep_aspect_ratio: bool):
self.output_shape = output_shape
self.keep_aspect_ratio = keep_aspect_ratio

def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, RescaleMetadata]:
rescale_shape = self.output_shape
scale_factor_h, scale_factor_w = rescale_shape[0] / image.shape[0], rescale_shape[1] / image.shape[1]

if self.keep_aspect_ratio:
scale_factor = min(scale_factor_h, scale_factor_w)
scale_factor_h, scale_factor_w = (scale_factor, scale_factor)
rescale_shape = (int(image.shape[0] * scale_factor_h), int(image.shape[1] * scale_factor_w))

rescaled_image = _rescale_image(image, target_shape=rescale_shape)

return rescaled_image, RescaleMetadata(original_shape=image.shape[:2], scale_factor_h=scale_factor_h, scale_factor_w=scale_factor_w)


class DetectionRescale(_Rescale):
def postprocess_predictions(self, predictions: np.ndarray, metadata: RescaleMetadata) -> np.ndarray:
return _rescale_bboxes(targets=predictions, scale_factors=(1 / metadata.scale_factor_h, 1 / metadata.scale_factor_w))


class SegmentationRescale(_Rescale):
def postprocess_predictions(self, predictions: np.ndarray, metadata: RescaleMetadata) -> np.ndarray:
return _rescale_image(predictions, target_shape=metadata.original_shape)
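
For context, a minimal sketch of how these processing blocks could compose for a detection model. The specific pipeline order, output shape, padding value, normalization statistics, and the (x1, y1, x2, y2) prediction layout below are illustrative assumptions, not something defined in this PR:

import numpy as np

from super_gradients.training.transforms.processing import (
    ComposeProcessing,
    DetectionRescale,
    DetectionCenterPadding,
    NormalizeImage,
    ImagePermute,
)

# Hypothetical detection preprocessing chain (values are illustrative).
pipeline = ComposeProcessing(
    [
        DetectionRescale(output_shape=(640, 640), keep_aspect_ratio=True),
        DetectionCenterPadding(output_shape=(640, 640), pad_value=114),
        NormalizeImage(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
        ImagePermute(),  # HWC -> CHW
    ]
)

image = np.random.randint(0, 255, size=(480, 640, 3), dtype=np.uint8)  # dummy HWC image
model_input, metadata = pipeline.preprocess_image(image)  # CHW array + one metadata entry per step

# Pretend the model returned one box as (x1, y1, x2, y2) in the padded/rescaled frame;
# postprocess_predictions walks the steps in reverse to map it back to the original image.
raw_predictions = np.array([[100.0, 100.0, 300.0, 300.0]])
boxes_in_original_frame = pipeline.postprocess_predictions(raw_predictions, metadata)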