From cecdb0c8c364ea049a3b705275ae71a2f366d4da Mon Sep 17 00:00:00 2001 From: iuls <54384769+iuls@users.noreply.github.com> Date: Tue, 9 Jun 2020 06:24:45 -0400 Subject: [PATCH] FlipAxis fix (#57) Co-authored-by: iulia --- deepposekit/augment/FlipAxis.py | 127 +++++++++++++------------------- 1 file changed, 52 insertions(+), 75 deletions(-) diff --git a/deepposekit/augment/FlipAxis.py b/deepposekit/augment/FlipAxis.py index 3be5563..08529c7 100644 --- a/deepposekit/augment/FlipAxis.py +++ b/deepposekit/augment/FlipAxis.py @@ -14,16 +14,14 @@ # limitations under the License. import numpy as np -import imgaug.augmenters as iaa -import six.moves as sm -import h5py - from deepposekit.io.BaseGenerator import BaseGenerator +from imgaug.augmenters import meta +from imgaug import parameters as iap __all__ = ["FlipAxis"] -class FlipAxis(iaa.Flipud): +class FlipAxis(meta.Augmenter): """ Flips the input image and keypoints across an axis. A generalized class for flipping images and keypoints @@ -39,103 +37,82 @@ class FlipAxis(iaa.Flipud): This can be a deepposekit.io.BaseGenerator for annotations or an array of integers specifying which keypoint indices to swap. + + p: int, default 0.5 + The probability that an image is flipped axis: int, default 0 Axis over which images are flipped axis=0 flips up-down (np.flipud) axis=1 flips left-right (np.fliplr) - + + seed: None or int or np.random.RandomState, default None + The random state for the augmenter. + name: None or str, default None Name given to the Augmenter object. The name is used in print(). If left as None, will print 'UnnamedX' - + deterministic: bool, default False If set to true, each batch will be augmented the same way. - random_state: None or int or np.random.RandomState, default None - The random state for the augmenter. Attributes ---------- + p: int + The probability that an image is flipped + axis: int The axis to reflect the image. swap_index: array The keypoint indices to swap when the image is flipped + """ - def __init__( - self, - swap_index, - p=0.5, - axis=0, - name=None, - deterministic=False, - random_state=None, - ): - - super(FlipAxis, self).__init__( - p=p, name=name, deterministic=deterministic, random_state=random_state - ) - + def __init__(self, swap_index, p=0.5, axis=0, seed=None, name=None, deterministic=False): + super(FlipAxis, self).__init__(seed=seed, name=name, random_state="deprecated", deterministic=deterministic) + self.p = iap.handle_probability_param(p, "p") self.axis = axis if isinstance(swap_index, BaseGenerator): if hasattr(swap_index, "swap_index"): self.swap_index = swap_index.swap_index elif isinstance(swap_index, np.ndarray): self.swap_index = swap_index + + + def _augment_batch_(self, batch, random_state, parents, hooks): + samples = self.p.draw_samples((batch.nb_rows,), + random_state=random_state) + for i, sample in enumerate(samples): + if sample >= 0.5: + + if batch.images is not None: + if self.axis == 0: + batch.images[i] = np.flipud(batch.images[i]) + if self.axis == 1: + batch.images[i] = np.fliplr(batch.images[i]) + - def _augment_images(self, images, random_state, parents, hooks): - """ Augments the images - - Handles the augmentation over a specified axis - - Returns - ------- - images: array - Array of augmented images. - - """ - nb_images = len(images) - samples = self.p.draw_samples((nb_images,), random_state=random_state) - for i in sm.xrange(nb_images): - if samples[i] == 1: - if self.axis == 1: - images[i] = np.fliplr(images[i]) - elif self.axis == 0: - images[i] = np.flipud(images[i]) - self.samples = samples - return images - - def _augment_keypoints(self, keypoints_on_images, random_state, parents, hooks): - """ Augments the keypoints - - Handles the augmentation over a specified axis - and swaps the keypoint labels using swap_index. - For example, the left leg will be swapped with the right leg - This is accomplished by reordering the keypoints. - - Returns - ------- - keypoints_on_images: array - Array of new coordinates of the keypoints. - - """ - nb_images = len(keypoints_on_images) - samples = self.p.draw_samples((nb_images,), random_state=random_state) - for i, keypoints_on_image in enumerate(keypoints_on_images): - if samples[i] == 1: - for keypoint in keypoints_on_image.keypoints: + if batch.keypoints is not None: + kpsoi = batch.keypoints[i] + if self.axis == 0: + height = kpsoi.shape[0] + for kp in kpsoi.keypoints: + kp.y = (height-1) - kp.y if self.axis == 1: - width = keypoints_on_image.shape[1] - keypoint.x = (width - 1) - keypoint.x - elif self.axis == 0: - height = keypoints_on_image.shape[0] - keypoint.y = (height - 1) - keypoint.y - swapped = keypoints_on_image.keypoints.copy() - for r in range(len(keypoints_on_image.keypoints)): - idx = self.swap_index[r] - if idx >= 0: - keypoints_on_image.keypoints[r] = swapped[idx] - return keypoints_on_images + width = kpsoi.shape[1] + for kp in kpsoi.keypoints: + kp.x = (width-1) - kp.x + swapped = kpsoi.keypoints.copy() + for r in range(len(kpsoi.keypoints)): + idx = self.swap_index[r] + if idx >= 0: + kpsoi.keypoints[r] = swapped[idx] + + return batch + + def get_parameters(self): + """See :func:`~imgaug.augmenters.meta.Augmenter.get_parameters`.""" + return [self.p, self.axis, self.swap_index] \ No newline at end of file