Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement AutoAugment for Detection #6609

Open
wants to merge 45 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 43 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
0bd673b
init commit
ain-soph Sep 19, 2022
e29e444
a small update
ain-soph Sep 19, 2022
08f5d12
update
ain-soph Sep 19, 2022
779c1a6
fix type checking issues
ain-soph Sep 19, 2022
dcd3f18
Merge branch 'pytorch:main' into Implement-AutoAugment-for-Detection
ain-soph Oct 5, 2022
0e8df49
(In Progress) temp commit
ain-soph Oct 5, 2022
16f6823
finish the majority
ain-soph Oct 6, 2022
75740f2
Merge branch 'Implement-AutoAugment-for-Detection' into main
ain-soph Oct 24, 2022
a01572b
Merge pull request #1 from ain-soph/main
ain-soph Oct 24, 2022
0dfc95b
Merge branch 'pytorch:main' into Implement-AutoAugment-for-Detection
ain-soph Nov 4, 2022
e712bc5
update
ain-soph Nov 5, 2022
936c704
remove pyd files
ain-soph Nov 5, 2022
20fc9a5
fix type linting errors
ain-soph Nov 5, 2022
b7a8adf
add test
ain-soph Nov 5, 2022
faa1187
format test file
ain-soph Nov 5, 2022
7daaf98
Merge branch 'pytorch:main' into Implement-AutoAugment-for-Detection
ain-soph Dec 20, 2022
04ebdc7
update codes
ain-soph Dec 20, 2022
115136e
Merge branch 'pytorch:main' into Implement-AutoAugment-for-Detection
ain-soph Dec 20, 2022
1c8f0bc
fix test file
ain-soph Dec 20, 2022
d46d4c7
another test fix
ain-soph Dec 20, 2022
c813b33
Merge branch 'main' into Implement-AutoAugment-for-Detection
ain-soph Jan 15, 2023
df667d5
Merge branch 'main' of github.com:pytorch/vision into Implement-AutoA…
vfdev-5 Feb 3, 2023
60d5a9b
Few fixes and improvements
vfdev-5 Feb 3, 2023
da70616
Merge branch 'pytorch:main' into Implement-AutoAugment-for-Detection
ain-soph Feb 13, 2023
4e6e8a7
Merge branch
ain-soph Mar 2, 2023
8bbacdb
update
ain-soph Mar 2, 2023
87dd82d
add solarize_add and cutout, add test
ain-soph Mar 3, 2023
48b5539
fix cutout bbox case
ain-soph Mar 3, 2023
e38d9ed
fix solarize_add and cutout
ain-soph Mar 3, 2023
a38df46
remove 2 useless comments
ain-soph Mar 3, 2023
b15cb71
add some comments
ain-soph Mar 3, 2023
73316eb
fix type linting
ain-soph Mar 3, 2023
6020bbc
Merge branch 'pytorch:main' into Implement-AutoAugment-for-Detection
ain-soph Mar 9, 2023
2444128
Merge branch 'main' into Implement-AutoAugment-for-Detection
ain-soph Apr 6, 2023
ebc4341
update
ain-soph Apr 6, 2023
338d4ed
remove test policy
ain-soph Apr 6, 2023
8a8a569
Merge branch 'main' into Implement-AutoAugment-for-Detection
ain-soph Apr 19, 2023
bb1013a
Merge branch 'main' into Implement-AutoAugment-for-Detection
ain-soph May 8, 2023
29c9db5
remove todos
ain-soph May 8, 2023
417b744
Merge branch 'Implement-AutoAugment-for-Detection' of https://github.…
ain-soph May 8, 2023
4e978b1
Merge branch 'pytorch:main' into Implement-AutoAugment-for-Detection
ain-soph May 16, 2023
791d620
update
ain-soph May 16, 2023
3afeb30
AutoAugmentDetection -> _AutoAugmentDetection before we validate the …
vfdev-5 May 17, 2023
89f22ce
Merge branch 'main' into Implement-AutoAugment-for-Detection
ain-soph Jun 14, 2023
5425d95
Merge branch 'main' into Implement-AutoAugment-for-Detection
ain-soph Jul 6, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,28 @@ def auto_augment_adapter(transform, input, device):
return adapted_input


def auto_augment_detection_adapter(transform, input, device):
adapted_input = {}
image_or_video_found = False
bounding_box_found = False
for key, value in input.items():
if isinstance(value, datapoints.Mask):
# AA detection transforms don't support masks
continue
elif isinstance(value, datapoints.BoundingBox):
if bounding_box_found:
# AA detection transforms only support a single bounding box tensor
continue
bounding_box_found = True
vfdev-5 marked this conversation as resolved.
Show resolved Hide resolved
elif check_type(value, (datapoints.Image, datapoints.Video, is_simple_tensor, PIL.Image.Image)):
if image_or_video_found:
# AA transforms only support a single image or video
continue
image_or_video_found = True
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may want to test different images and videos

adapted_input[key] = value
return adapted_input


def linear_transformation_adapter(transform, input, device):
flat_inputs = list(input.values())
c, h, w = query_chw(
Expand Down Expand Up @@ -119,6 +141,10 @@ class TestSmoke:
(transforms.AutoAugment(), auto_augment_adapter),
(transforms.RandAugment(), auto_augment_adapter),
(transforms.TrivialAugmentWide(), auto_augment_adapter),
(transforms._AutoAugmentDetection("v0"), auto_augment_detection_adapter),
(transforms._AutoAugmentDetection("v1"), auto_augment_detection_adapter),
(transforms._AutoAugmentDetection("v2"), auto_augment_detection_adapter),
(transforms._AutoAugmentDetection("v3"), auto_augment_detection_adapter),
(transforms.ColorJitter(brightness=0.1, contrast=0.2, saturation=0.3, hue=0.15), None),
(transforms.Grayscale(), None),
(transforms.RandomAdjustSharpness(sharpness_factor=0.5, p=1.0), None),
Expand Down Expand Up @@ -310,6 +336,37 @@ def test_common(self, transform, adapter, container_type, image_or_video, device
def test_auto_augment(self, transform, input):
transform(input)

@pytest.mark.parametrize(
"transform",
[
transforms._AutoAugmentDetection("v0"),
transforms._AutoAugmentDetection("v1"),
transforms._AutoAugmentDetection("v2"),
transforms._AutoAugmentDetection("v3"),
],
)
@pytest.mark.parametrize("seed", range(10))
def test_auto_augment_detection(self, transform, seed):
torch.manual_seed(seed)
image = datapoints.Image(torch.randint(0, 256, size=(3, 480, 640), dtype=torch.uint8))
boxes = torch.tensor(
[
[388.3100, 38.8300, 638.5600, 480.0000],
[82.6800, 314.7300, 195.4400, 445.7400],
[199.1000, 214.1700, 316.4100, 399.2800],
[159.8400, 183.6800, 216.8700, 273.1200],
[15.8000, 265.4600, 93.2900, 395.1800],
[88.2300, 266.0800, 222.9800, 371.2500],
[176.9000, 283.8600, 208.4300, 292.3000],
[537.6300, 230.8300, 580.7100, 291.3300],
[421.1200, 230.4700, 580.6700, 350.9500],
[427.4200, 185.4300, 494.0200, 266.3500],
]
)
bboxes = datapoints.BoundingBox(boxes, format="XYXY", spatial_size=image.shape[-2:])
input = (image, bboxes)
transform(input)

@parametrize(
[
(
Expand Down
2 changes: 1 addition & 1 deletion torchvision/transforms/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ._transform import Transform # usort: skip

from ._augment import RandomErasing
from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
from ._auto_augment import _AutoAugmentDetection, AugMix, AutoAugment, RandAugment, TrivialAugmentWide
from ._color import (
ColorJitter,
Grayscale,
Expand Down
Loading