
Commit

AutoAugmentDetection -> _AutoAugmentDetection before we validate the results

Added another smoke test
vfdev-5 committed May 17, 2023
1 parent 791d620 commit 3afeb30
Showing 3 changed files with 68 additions and 11 deletions.
39 changes: 35 additions & 4 deletions test/test_transforms_v2.py
@@ -141,10 +141,10 @@ class TestSmoke:
         (transforms.AutoAugment(), auto_augment_adapter),
         (transforms.RandAugment(), auto_augment_adapter),
         (transforms.TrivialAugmentWide(), auto_augment_adapter),
-        (transforms.AutoAugmentDetection("v0"), auto_augment_detection_adapter),
-        (transforms.AutoAugmentDetection("v1"), auto_augment_detection_adapter),
-        (transforms.AutoAugmentDetection("v2"), auto_augment_detection_adapter),
-        (transforms.AutoAugmentDetection("v3"), auto_augment_detection_adapter),
+        (transforms._AutoAugmentDetection("v0"), auto_augment_detection_adapter),
+        (transforms._AutoAugmentDetection("v1"), auto_augment_detection_adapter),
+        (transforms._AutoAugmentDetection("v2"), auto_augment_detection_adapter),
+        (transforms._AutoAugmentDetection("v3"), auto_augment_detection_adapter),
         (transforms.ColorJitter(brightness=0.1, contrast=0.2, saturation=0.3, hue=0.15), None),
         (transforms.Grayscale(), None),
         (transforms.RandomAdjustSharpness(sharpness_factor=0.5, p=1.0), None),
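The only functional change in this hunk is the leading underscore: the transform remains importable, but its name now signals a private API while the results are being validated. A minimal sketch of the difference, assuming this development branch is installed:

```python
from torchvision.transforms import v2 as transforms

# Before this commit the transform was exposed under a public name:
#   transforms.AutoAugmentDetection("v0")
# After it, only the underscore-prefixed (private) name exists:
transform = transforms._AutoAugmentDetection("v0")
```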
@@ -336,6 +336,37 @@ def test_common(self, transform, adapter, container_type, image_or_video, device
     def test_auto_augment(self, transform, input):
         transform(input)
 
+    @pytest.mark.parametrize(
+        "transform",
+        [
+            transforms._AutoAugmentDetection("v0"),
+            transforms._AutoAugmentDetection("v1"),
+            transforms._AutoAugmentDetection("v2"),
+            transforms._AutoAugmentDetection("v3"),
+        ],
+    )
+    @pytest.mark.parametrize("seed", range(10))
+    def test_auto_augment_detection(self, transform, seed):
+        torch.manual_seed(seed)
+        image = datapoints.Image(torch.randint(0, 256, size=(3, 480, 640), dtype=torch.uint8))
+        boxes = torch.tensor(
+            [
+                [388.3100, 38.8300, 638.5600, 480.0000],
+                [82.6800, 314.7300, 195.4400, 445.7400],
+                [199.1000, 214.1700, 316.4100, 399.2800],
+                [159.8400, 183.6800, 216.8700, 273.1200],
+                [15.8000, 265.4600, 93.2900, 395.1800],
+                [88.2300, 266.0800, 222.9800, 371.2500],
+                [176.9000, 283.8600, 208.4300, 292.3000],
+                [537.6300, 230.8300, 580.7100, 291.3300],
+                [421.1200, 230.4700, 580.6700, 350.9500],
+                [427.4200, 185.4300, 494.0200, 266.3500],
+            ]
+        )
+        bboxes = datapoints.BoundingBox(boxes, format="XYXY", spatial_size=image.shape[-2:])
+        input = (image, bboxes)
+        transform(input)
+
     @parametrize(
         [
             (
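For reference, the new smoke test boils down to the following standalone usage pattern. This is a sketch assuming the branch above is installed; `datapoints.Image` and `datapoints.BoundingBox` are the v2 API names used by the test itself:

```python
import torch
from torchvision import datapoints
from torchvision.transforms import v2 as transforms

# A uint8 image and XYXY boxes, mirroring the fixture in the new test.
image = datapoints.Image(torch.randint(0, 256, size=(3, 480, 640), dtype=torch.uint8))
boxes = datapoints.BoundingBox(
    torch.tensor([[388.31, 38.83, 638.56, 480.0]]),
    format="XYXY",
    spatial_size=image.shape[-2:],
)

transform = transforms._AutoAugmentDetection("v0")
out_image, out_boxes = transform((image, boxes))  # boxes are transformed alongside the image
```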
2 changes: 1 addition & 1 deletion torchvision/transforms/v2/__init__.py
@@ -5,7 +5,7 @@
 from ._transform import Transform  # usort: skip
 
 from ._augment import RandomErasing
-from ._auto_augment import AugMix, AutoAugment, AutoAugmentDetection, RandAugment, TrivialAugmentWide
+from ._auto_augment import _AutoAugmentDetection, AugMix, AutoAugment, RandAugment, TrivialAugmentWide
 from ._color import (
     ColorJitter,
     Grayscale,
38 changes: 32 additions & 6 deletions torchvision/transforms/v2/_auto_augment.py
@@ -220,7 +220,7 @@ def _apply_transform(

 class AutoAugment(_AutoAugmentBase):
     r"""[BETA] AutoAugment data augmentation method based on
-    `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
+    `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/abs/1805.09501>`_.
 
     .. v2betastatus:: AutoAugment transform
@@ -798,7 +798,33 @@ def _apply_image_or_video_and_bboxes_transform(
         return image, bboxes
 
 
-class AutoAugmentDetection(_AutoAugmentDetectionBase):
+class _AutoAugmentDetection(_AutoAugmentDetectionBase):
+    r"""[BETA] AutoAugment data augmentation method for the object detection task, based on
+    `"Learning Data Augmentation Strategies for Object Detection" <https://arxiv.org/abs/1906.11172>`_.
+
+    .. v2betastatus:: AutoAugment transform
+
+    This transformation works on images, videos and bounding boxes only.
+
+    If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
+    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
+    If img is a PIL Image, it is expected to be in mode "L" or "RGB".
+
+    Args:
+        policy (str, optional): The name of the AutoAugment policy to use. The available
+            options are `v0`, `v1`, `v2`, `v3` and `test`. `v0` is the policy used for
+            all of the results in the paper and was found to achieve the best results
+            on the COCO dataset. `v1`, `v2` and `v3` are additional good policies
+            found on the COCO dataset that vary slightly in which operations were
+            used during the search procedure and in how many operations are applied
+            in parallel to a single image (2 vs 3).
+        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
+            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
+            If input is Tensor, only ``InterpolationMode.NEAREST`` and ``InterpolationMode.BILINEAR`` are supported.
+        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
+            image. If given a number, the value is used for all bands.
+    """
+
     _AUGMENTATION_SPACE = {
         "AutoContrast": (lambda num_bins, height, width: None, False),
         "Equalize": (lambda num_bins, height, width: None, False),
@@ -821,11 +847,11 @@ class AutoAugmentDetection(_AutoAugmentDetectionBase):
"Cutout": (lambda num_bins, height, width: torch.linspace(0, 100, num_bins + 1).round().int(), False),
"BBox_Cutout": (lambda num_bins, height, width: torch.linspace(0.0, 0.75, num_bins + 1), False),
"TranslateX": (
lambda num_bins, height, width: torch.linspace(0.0, 250.0, num_bins + 1),
lambda num_bins, height, width: torch.linspace(0.0, 250.0 / 331.0 * width, num_bins + 1),
True,
),
"TranslateY": (
lambda num_bins, height, width: torch.linspace(0.0, 250.0, num_bins + 1),
lambda num_bins, height, width: torch.linspace(0.0, 250.0 / 331.0 * height, num_bins + 1),
True,
),
"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins + 1), True),
@@ -835,11 +861,11 @@ class AutoAugmentDetection(_AutoAugmentDetectionBase):
"ShearX_Only_BBoxes": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins + 1), True),
"ShearY_Only_BBoxes": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins + 1), True),
"TranslateX_Only_BBoxes": (
lambda num_bins, height, width: torch.linspace(0.0, 120.0, num_bins + 1),
lambda num_bins, height, width: torch.linspace(0.0, 120.0 / 331.0 * width, num_bins + 1),
True,
),
"TranslateY_Only_BBoxes": (
lambda num_bins, height, width: torch.linspace(0.0, 120.0, num_bins + 1),
lambda num_bins, height, width: torch.linspace(0.0, 120.0 / 331.0 * height, num_bins + 1),
True,
),
"Flip_Only_BBoxes": (lambda num_bins, height, width: None, False),
