Skip to content

Commit

Permalink
Fixed bug in _pad_image that did not support pad_value=(R,B,G) input (#…
Browse files Browse the repository at this point in the history
…1599)

* Fixed bug in _pad_image that did not support pad_value=(R,B,G) input

* Added checking for pad_value when input is image of HW shape

* More efficient padding implementation
  • Loading branch information
BloodAxe committed Nov 6, 2023
1 parent 4d3abe1 commit 29dea7a
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 6 deletions.
37 changes: 32 additions & 5 deletions src/super_gradients/training/transforms/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Tuple
import numbers
import typing
from typing import Tuple, Union
from dataclasses import dataclass
import cv2

Expand Down Expand Up @@ -106,22 +108,47 @@ def _get_bottom_right_padding_coordinates(input_shape: Tuple[int, int], output_s
return PaddingCoordinates(top=0, bottom=pad_height, left=0, right=pad_width)


def _pad_image(image: np.ndarray, padding_coordinates: PaddingCoordinates, pad_value: int) -> np.ndarray:
def _pad_image(image: np.ndarray, padding_coordinates: PaddingCoordinates, pad_value: Union[int, Tuple[int, ...]]) -> np.ndarray:
"""Pad an image.
:param image: Image to shift. (H, W, C) or (H, W).
:param pad_h: Tuple of (padding_top, padding_bottom).
:param pad_w: Tuple of (padding_left, padding_right).
:param pad_value: Padding value
:param pad_value: Padding value. Can be a single scalar (Same value for all channels) or a tuple of values.
In the latter case, the tuple length must be equal to the number of channels.
:return: Image shifted according to padding coordinates.
"""
pad_h = (padding_coordinates.top, padding_coordinates.bottom)
pad_w = (padding_coordinates.left, padding_coordinates.right)

if len(image.shape) == 3:
return np.pad(image, (pad_h, pad_w, (0, 0)), "constant", constant_values=pad_value)
_, _, num_channels = image.shape

if isinstance(pad_value, numbers.Number):
pad_value = tuple([pad_value] * num_channels)
else:
if isinstance(pad_value, typing.Sized) and len(pad_value) != num_channels:
raise ValueError(f"A pad_value tuple ({pad_value} length should be {num_channels} for an image with {num_channels} channels")

pad_value = tuple(pad_value)

constant_values = ((pad_value, pad_value), (pad_value, pad_value), (0, 0))
padding_values = (pad_h, pad_w, (0, 0))
else:
return np.pad(image, (pad_h, pad_w), "constant", constant_values=pad_value)
if isinstance(pad_value, numbers.Number):
pass
elif isinstance(pad_value, typing.Sized):
if len(pad_value) != 1:
raise ValueError(f"A pad_value tuple ({pad_value} length should be 1 for a grayscale image")
else:
(pad_value,) = pad_value # Unpack to a single scalar
else:
raise ValueError(f"Unsupported pad_value type {type(pad_value)}")

constant_values = pad_value
padding_values = (pad_h, pad_w)

return np.pad(image, pad_width=padding_values, mode="constant", constant_values=constant_values)


def _shift_bboxes(targets: np.array, shift_w: float, shift_h: float) -> np.array:
Expand Down
45 changes: 44 additions & 1 deletion tests/unit_tests/transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import cv2
import matplotlib.pyplot as plt
import numpy as np
from omegaconf import ListConfig

from super_gradients.training.transforms import KeypointsMixup, KeypointsCompose
from super_gradients.training.transforms.keypoint_transforms import (
Expand Down Expand Up @@ -244,7 +245,7 @@ def test_rescale_bboxes(self):
rescaled_bboxes = _rescale_bboxes(targets=bboxes, scale_factors=(sy, sx))
np.testing.assert_array_equal(rescaled_bboxes, expected_bboxes)

def test_pad_image(self):
def test_pad_image_with_constant(self):
image = np.random.randint(0, 256, size=(640, 480, 3), dtype=np.uint8)
padding_coordinates = PaddingCoordinates(top=80, bottom=80, left=60, right=60)
pad_value = 0
Expand All @@ -258,6 +259,48 @@ def test_pad_image(self):
self.assertTrue((shifted_image[:, : padding_coordinates.left, :] == pad_value).all())
self.assertTrue((shifted_image[:, -padding_coordinates.right :, :] == pad_value).all())

def test_pad_image_with_tuple(self):
image = np.random.randint(0, 256, size=(640, 480, 3), dtype=np.uint8)
padding_coordinates = PaddingCoordinates(top=80, bottom=80, left=60, right=60)
pad_value = (1, 2, 3)
shifted_image = _pad_image(image, padding_coordinates, pad_value)

# Check if the shifted image has the correct shape
self.assertEqual(shifted_image.shape, (800, 600, 3))
# Check if the padding values are correct
self.assertTrue((shifted_image[: padding_coordinates.top, :, :] == pad_value).all())
self.assertTrue((shifted_image[-padding_coordinates.bottom :, :, :] == pad_value).all())
self.assertTrue((shifted_image[:, : padding_coordinates.left, :] == pad_value).all())
self.assertTrue((shifted_image[:, -padding_coordinates.right :, :] == pad_value).all())

def test_pad_image_with_listconfig(self):
image = np.random.randint(0, 256, size=(640, 480, 3), dtype=np.uint8)
padding_coordinates = PaddingCoordinates(top=80, bottom=80, left=60, right=60)
pad_value = ListConfig([1, 2, 3])
shifted_image = _pad_image(image, padding_coordinates, pad_value)

# Check if the shifted image has the correct shape
self.assertEqual(shifted_image.shape, (800, 600, 3))
# Check if the padding values are correct
self.assertTrue((shifted_image[: padding_coordinates.top, :, :] == pad_value).all())
self.assertTrue((shifted_image[-padding_coordinates.bottom :, :, :] == pad_value).all())
self.assertTrue((shifted_image[:, : padding_coordinates.left, :] == pad_value).all())
self.assertTrue((shifted_image[:, -padding_coordinates.right :, :] == pad_value).all())

def test_pad_grayscale_image(self):
image = np.random.randint(0, 256, size=(640, 480), dtype=np.uint8)
padding_coordinates = PaddingCoordinates(top=80, bottom=80, left=60, right=60)
pad_value = 1
shifted_image = _pad_image(image, padding_coordinates, pad_value)

# Check if the shifted image has the correct shape
self.assertEqual(shifted_image.shape, (800, 600))
# Check if the padding values are correct
self.assertTrue((shifted_image[: padding_coordinates.top, :] == pad_value).all())
self.assertTrue((shifted_image[-padding_coordinates.bottom :, :] == pad_value).all())
self.assertTrue((shifted_image[:, : padding_coordinates.left] == pad_value).all())
self.assertTrue((shifted_image[:, -padding_coordinates.right :] == pad_value).all())

def test_shift_bboxes(self):
bboxes = np.array([[10, 20, 50, 60, 1], [30, 40, 80, 90, 2]], dtype=np.float32)
shift_w, shift_h = 60, 80
Expand Down

0 comments on commit 29dea7a

Please sign in to comment.