Add detection vertical flip augmentation #1234

Closed
1 change: 1 addition & 0 deletions src/super_gradients/common/object_names.py
@@ -57,6 +57,7 @@ class Transforms:
DetectionRGB2BGR = "DetectionRGB2BGR"
DetectionRandomRotate90 = "DetectionRandomRotate90"
DetectionHorizontalFlip = "DetectionHorizontalFlip"
DetectionVerticalFlip = "DetectionVerticalFlip"
DetectionRescale = "DetectionRescale"
DetectionPadToSize = "DetectionPadToSize"
DetectionImagePermute = "DetectionImagePermute"
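Registering the new string in Transforms is what allows the transform to be referenced by name from recipe and dataset configs. A minimal sketch of how it could be wired in, assuming the usual name-to-kwargs mapping style used for detection transforms in SuperGradients dataset params (the surrounding structure and values are illustrative, not taken from this PR):

transforms = [
    {"DetectionHorizontalFlip": {"prob": 0.5}},
    {"DetectionVerticalFlip": {"prob": 0.5}},  # resolvable by name once registered above
]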
92 changes: 75 additions & 17 deletions src/super_gradients/training/transforms/transforms.py
@@ -655,8 +655,10 @@ def __call__(self, sample: dict) -> dict:
cp_sample = sample["additional_samples"][0]
img, cp_labels = cp_sample["image"], cp_sample["target"]
cp_boxes = cp_labels[:, :4]

img, cp_boxes = _mirror(img, cp_boxes, self.flip_prob)
if random.random() < self.prob:
_, width, _ = img.shape
img = _flip_horizontal_image(img)
cp_boxes = _flip_horizontal_boxes(cp_boxes, width)
# plug the flipped boxes back into the target
cp_labels[:, :4] = cp_boxes

@@ -826,13 +828,48 @@ def __init__(self, prob: float, max_targets: Optional[int] = None):

def __call__(self, sample):
image, targets = sample["image"], sample["target"]
crowd_targets = sample.get("crowd_targets")
if len(targets) == 0:
targets = np.zeros((0, 5), dtype=np.float32)
boxes = targets[:, :4]
image, boxes = _mirror(image, boxes, self.prob)
targets[:, :4] = boxes
if random.random() < self.prob:
image = _flip_horizontal_image(image)
_, width, _ = image.shape
targets[:, :4] = _flip_horizontal_boxes(targets[:, :4], width)
if crowd_targets is not None:
crowd_targets[:, :4] = _flip_horizontal_boxes(crowd_targets[:, :4], width)
sample["image"] = image
sample["target"] = targets
sample["crowd_targets"] = crowd_targets
return sample


@register_transform(Transforms.DetectionVerticalFlip)
class DetectionVerticalFlip(DetectionTransform):
"""
Vertical Flip for Detection

:param prob: Probability of applying vertical flip
"""

def __init__(self, prob: float, max_targets: Optional[int] = None):
super(DetectionVerticalFlip, self).__init__()
_max_targets_deprication(max_targets)
self.prob = prob

def __call__(self, sample):
image, targets = sample["image"], sample["target"]
crowd_targets = sample.get("crowd_targets")
if len(targets) == 0:
targets = np.zeros((0, 5), dtype=np.float32)
if random.random() < self.prob:
image = _flip_vertical_image(image)
height, _, _ = image.shape
targets[:, :4] = _flip_vertical_boxes(targets[:, :4], height)
if crowd_targets is not None:
crowd_targets[:, :4] = _flip_vertical_boxes(crowd_targets[:, :4], height)
sample["image"] = image
sample["target"] = targets
sample["crowd_targets"] = crowd_targets
return sample
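A minimal usage sketch of the new transform on a standalone sample dict; the HWC image shape and the (x1, y1, x2, y2, class_id) target layout follow the conventions used by the detection transforms and the unit tests below, and the concrete numbers are illustrative:

import numpy as np
from super_gradients.training.transforms.transforms import DetectionVerticalFlip

aug = DetectionVerticalFlip(prob=1.0)  # prob=1.0 so the flip is always applied

sample = {
    "image": np.zeros((100, 200, 3), dtype=np.uint8),     # HWC image, height = 100
    "target": np.array([[10.0, 10.0, 20.0, 30.0, 0.0]]),  # x1, y1, x2, y2, class_id
}
out = aug(sample)
# y-coordinates are mirrored around the image height:
# y1' = 100 - 30 = 70, y2' = 100 - 10 = 90
print(out["target"])  # [[10. 70. 20. 90.  0.]]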


@@ -1277,21 +1314,42 @@ def _filter_box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1):
return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + 1e-16) > area_thr) & (ar < ar_thr) # candidates


def _mirror(image, boxes, prob=0.5):
    """
    Horizontally flips image and bboxes with probability prob.

    :param image: (np.array) image to be flipped.
    :param boxes: (np.array) bboxes to be modified.
    :param prob: probability to perform flipping.
    :return: flipped_image, flipped_bboxes
    """
    flipped_boxes = boxes.copy()
    _, width, _ = image.shape
    if random.random() < prob:
        image = image[:, ::-1]
        flipped_boxes[:, 0::2] = width - boxes[:, 2::-2]
    return image, flipped_boxes


def _flip_horizontal_image(image: np.ndarray) -> np.ndarray:
    """
    Horizontally flips an image.

    :param image: image to be flipped.
    :return: flipped image
    """
    return image[:, ::-1]


def _flip_horizontal_boxes(boxes: np.ndarray, img_width: int) -> np.ndarray:
    """
    Horizontally flips bboxes given in xyxy format.

    :param boxes: bboxes to be flipped.
    :param img_width: width of the image the bboxes belong to.
    :return: flipped bboxes
    """
    boxes[:, [0, 2]] = img_width - boxes[:, [2, 0]]
    return boxes


def _flip_vertical_image(image: np.ndarray) -> np.ndarray:
    """
    Vertically flips an image.

    :param image: image to be flipped.
    :return: flipped image
    """
    return image[::-1, :]


def _flip_vertical_boxes(boxes: np.ndarray, img_height: int) -> np.ndarray:
    """
    Vertically flips bboxes given in xyxy format.

    :param boxes: bboxes to be flipped.
    :param img_height: height of the image the bboxes belong to.
    :return: flipped bboxes
    """
    boxes[:, [1, 3]] = img_height - boxes[:, [3, 1]]
    return boxes
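For a quick sanity check of the xyxy flip arithmetic above, run in the same module so the private helpers are in scope (the numbers match the expectations used in the unit tests below):

import numpy as np

boxes = np.array([[10.0, 10.0, 20.0, 20.0]])  # x1, y1, x2, y2
print(_flip_horizontal_boxes(boxes.copy(), img_width=100))  # [[80. 10. 90. 20.]]
print(_flip_vertical_boxes(boxes.copy(), img_height=100))   # [[10. 80. 20. 90.]]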


def augment_hsv(img: np.array, hgain: float, sgain: float, vgain: float, bgr_channels=(0, 1, 2)):
85 changes: 84 additions & 1 deletion tests/unit_tests/transforms_test.py
@@ -9,7 +9,12 @@
KeypointsPadIfNeeded,
KeypointsLongestMaxSize,
)
from super_gradients.training.transforms.transforms import DetectionImagePermute, DetectionPadToSize
from super_gradients.training.transforms.transforms import (
DetectionImagePermute,
DetectionPadToSize,
DetectionHorizontalFlip,
DetectionVerticalFlip,
)

from super_gradients.training.transforms.utils import (
_rescale_image,
@@ -140,6 +145,84 @@ def test_rescale_image(self):
# Check if the rescaled image has the correct target shape
self.assertEqual(rescaled_image.shape[:2], target_shape)

def test_detection_horizontal_flip(self):
aug = DetectionHorizontalFlip(prob=1)
image = np.random.rand(100, 100, 3)
image_original = image.copy()
# [x0, y0, x1, y1]
bboxes = np.array(
(
(10, 10, 20, 20),
(90, 90, 100, 100),
)
)
bboxes_expected = np.array(
(
(80, 10, 90, 20),
(0, 90, 10, 100),
)
)

# run transform
sample = {"image": image}
sample["target"] = bboxes
sample["crowd_targets"] = bboxes.copy()
output = aug(sample)
image = output["image"]
target = output["target"]
crowd_targets = output["crowd_targets"]

# check image hasn't changed shape
self.assertEqual(image.shape, image_original.shape)

# check that the first two columns of the original image
# match the last two columns of the flipped image
self.assertTrue(np.array_equal(image_original[:, 0], image[:, -1]))
self.assertTrue(np.array_equal(image_original[:, 1], image[:, -2]))

# check bboxes as expected
self.assertTrue(np.array_equal(target, bboxes_expected))
self.assertTrue(np.array_equal(crowd_targets, bboxes_expected))

def test_detection_vertical_flip(self):
aug = DetectionVerticalFlip(prob=1)
image = np.random.rand(100, 100, 3)
image_original = image.copy()
# [x0, y0, x1, y1]
bboxes = np.array(
(
(10, 10, 20, 20),
(90, 90, 100, 100),
)
)
bboxes_expected = np.array(
(
(10, 80, 20, 90),
(90, 0, 100, 10),
)
)

# run transform
sample = {"image": image}
sample["target"] = bboxes
sample["crowd_targets"] = bboxes.copy()
output = aug(sample)
image = output["image"]
target = output["target"]
crowd_targets = output["crowd_targets"]

# check image hasn't changed shape
self.assertEqual(image.shape, image_original.shape)

# check that the top two rows of the original image
# match the bottom two rows of the flipped image
self.assertTrue(np.array_equal(image_original[0], image[-1]))
self.assertTrue(np.array_equal(image_original[1], image[-2]))

# check bboxes as expected
self.assertTrue(np.array_equal(target, bboxes_expected))
self.assertTrue(np.array_equal(crowd_targets, bboxes_expected))

def test_rescale_bboxes(self):
sy, sx = (2.0, 0.5)
