Skip to content

Commit

Permalink
FlipAxis fix (jgraving#57)
Browse files Browse the repository at this point in the history
Co-authored-by: iulia <iuliaisaia@gmail.com>
  • Loading branch information
iuls and igheorghita committed Jun 9, 2020
1 parent c4f054b commit cecdb0c
Showing 1 changed file with 52 additions and 75 deletions.
127 changes: 52 additions & 75 deletions deepposekit/augment/FlipAxis.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,14 @@
# limitations under the License.

import numpy as np
import imgaug.augmenters as iaa
import six.moves as sm
import h5py

from deepposekit.io.BaseGenerator import BaseGenerator
from imgaug.augmenters import meta
from imgaug import parameters as iap

__all__ = ["FlipAxis"]


class FlipAxis(iaa.Flipud):
class FlipAxis(meta.Augmenter):
""" Flips the input image and keypoints across an axis.
A generalized class for flipping images and keypoints
Expand All @@ -39,103 +37,82 @@ class FlipAxis(iaa.Flipud):
This can be a deepposekit.io.BaseGenerator for annotations
or an array of integers specifying which keypoint indices
to swap.
p: int, default 0.5
The probability that an image is flipped
axis: int, default 0
Axis over which images are flipped
axis=0 flips up-down (np.flipud)
axis=1 flips left-right (np.fliplr)
seed: None or int or np.random.RandomState, default None
The random state for the augmenter.
name: None or str, default None
Name given to the Augmenter object. The name is used in print().
If left as None, will print 'UnnamedX'
deterministic: bool, default False
If set to true, each batch will be augmented the same way.
random_state: None or int or np.random.RandomState, default None
The random state for the augmenter.
Attributes
----------
p: int
The probability that an image is flipped
axis: int
The axis to reflect the image.
swap_index: array
The keypoint indices to swap when the image is flipped
"""

def __init__(
self,
swap_index,
p=0.5,
axis=0,
name=None,
deterministic=False,
random_state=None,
):

super(FlipAxis, self).__init__(
p=p, name=name, deterministic=deterministic, random_state=random_state
)

def __init__(self, swap_index, p=0.5, axis=0, seed=None, name=None, deterministic=False):
super(FlipAxis, self).__init__(seed=seed, name=name, random_state="deprecated", deterministic=deterministic)
self.p = iap.handle_probability_param(p, "p")
self.axis = axis
if isinstance(swap_index, BaseGenerator):
if hasattr(swap_index, "swap_index"):
self.swap_index = swap_index.swap_index
elif isinstance(swap_index, np.ndarray):
self.swap_index = swap_index


def _augment_batch_(self, batch, random_state, parents, hooks):
samples = self.p.draw_samples((batch.nb_rows,),
random_state=random_state)
for i, sample in enumerate(samples):
if sample >= 0.5:

if batch.images is not None:
if self.axis == 0:
batch.images[i] = np.flipud(batch.images[i])
if self.axis == 1:
batch.images[i] = np.fliplr(batch.images[i])


def _augment_images(self, images, random_state, parents, hooks):
""" Augments the images
Handles the augmentation over a specified axis
Returns
-------
images: array
Array of augmented images.
"""
nb_images = len(images)
samples = self.p.draw_samples((nb_images,), random_state=random_state)
for i in sm.xrange(nb_images):
if samples[i] == 1:
if self.axis == 1:
images[i] = np.fliplr(images[i])
elif self.axis == 0:
images[i] = np.flipud(images[i])
self.samples = samples
return images

def _augment_keypoints(self, keypoints_on_images, random_state, parents, hooks):
""" Augments the keypoints
Handles the augmentation over a specified axis
and swaps the keypoint labels using swap_index.
For example, the left leg will be swapped with the right leg
This is accomplished by reordering the keypoints.
Returns
-------
keypoints_on_images: array
Array of new coordinates of the keypoints.
"""
nb_images = len(keypoints_on_images)
samples = self.p.draw_samples((nb_images,), random_state=random_state)
for i, keypoints_on_image in enumerate(keypoints_on_images):
if samples[i] == 1:
for keypoint in keypoints_on_image.keypoints:
if batch.keypoints is not None:
kpsoi = batch.keypoints[i]
if self.axis == 0:
height = kpsoi.shape[0]
for kp in kpsoi.keypoints:
kp.y = (height-1) - kp.y
if self.axis == 1:
width = keypoints_on_image.shape[1]
keypoint.x = (width - 1) - keypoint.x
elif self.axis == 0:
height = keypoints_on_image.shape[0]
keypoint.y = (height - 1) - keypoint.y
swapped = keypoints_on_image.keypoints.copy()
for r in range(len(keypoints_on_image.keypoints)):
idx = self.swap_index[r]
if idx >= 0:
keypoints_on_image.keypoints[r] = swapped[idx]
return keypoints_on_images
width = kpsoi.shape[1]
for kp in kpsoi.keypoints:
kp.x = (width-1) - kp.x
swapped = kpsoi.keypoints.copy()
for r in range(len(kpsoi.keypoints)):
idx = self.swap_index[r]
if idx >= 0:
kpsoi.keypoints[r] = swapped[idx]

return batch

def get_parameters(self):
"""See :func:`~imgaug.augmenters.meta.Augmenter.get_parameters`."""
return [self.p, self.axis, self.swap_index]

0 comments on commit cecdb0c

Please sign in to comment.