-
Notifications
You must be signed in to change notification settings - Fork 6.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add --backend support to detection refs
- Loading branch information
1 parent 08c9938, commit 6443e6a
Showing 4 changed files with 131 additions and 79 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,73 +1,109 @@ | ||
from collections import defaultdict | ||
|
||
import torch | ||
import transforms as T | ||
import transforms as reference_transforms | ||
|
||
|
||
def get_modules(use_v2): | ||
# We need a protected import to avoid the V2 warning in case just V1 is used | ||
if use_v2: | ||
import torchvision.datapoints | ||
import torchvision.transforms.v2 | ||
|
||
return torchvision.transforms.v2, torchvision.datapoints | ||
else: | ||
return reference_transforms, None | ||
|
||
|
||
class DetectionPresetTrain: | ||
def __init__(self, *, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0)): | ||
def __init__( | ||
self, | ||
*, | ||
data_augmentation, | ||
hflip_prob=0.5, | ||
mean=(123.0, 117.0, 104.0), | ||
backend="pil", | ||
use_v2=False, | ||
): | ||
|
||
T, datapoints = get_modules(use_v2) | ||
|
||
transforms = [] | ||
backend = backend.lower() | ||
if backend == "datapoint": | ||
transforms.append(T.ToImageTensor()) | ||
elif backend == "tensor": | ||
transforms.append(T.PILToTensor()) | ||
elif backend != "pil": | ||
raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}") | ||
|
||
if data_augmentation == "hflip": | ||
self.transforms = T.Compose( | ||
[ | ||
T.RandomHorizontalFlip(p=hflip_prob), | ||
T.PILToTensor(), | ||
T.ConvertImageDtype(torch.float), | ||
] | ||
) | ||
transforms += [T.RandomHorizontalFlip(p=hflip_prob)] | ||
elif data_augmentation == "lsj": | ||
self.transforms = T.Compose( | ||
[ | ||
T.ScaleJitter(target_size=(1024, 1024)), | ||
T.FixedSizeCrop(size=(1024, 1024), fill=mean), | ||
T.RandomHorizontalFlip(p=hflip_prob), | ||
T.PILToTensor(), | ||
T.ConvertImageDtype(torch.float), | ||
] | ||
) | ||
transforms += [ | ||
T.ScaleJitter(target_size=(1024, 1024), antialias=True), | ||
# TODO: FixedSizeCrop below doesn't work on tensors! | ||
reference_transforms.FixedSizeCrop(size=(1024, 1024), fill=mean), | ||
T.RandomHorizontalFlip(p=hflip_prob), | ||
] | ||
elif data_augmentation == "multiscale": | ||
self.transforms = T.Compose( | ||
[ | ||
T.RandomShortestSize( | ||
min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333 | ||
), | ||
T.RandomHorizontalFlip(p=hflip_prob), | ||
T.PILToTensor(), | ||
T.ConvertImageDtype(torch.float), | ||
] | ||
) | ||
transforms += [ | ||
T.RandomShortestSize(min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333), | ||
T.RandomHorizontalFlip(p=hflip_prob), | ||
] | ||
elif data_augmentation == "ssd": | ||
self.transforms = T.Compose( | ||
[ | ||
T.RandomPhotometricDistort(), | ||
T.RandomZoomOut(fill=list(mean)), | ||
T.RandomIoUCrop(), | ||
T.RandomHorizontalFlip(p=hflip_prob), | ||
T.PILToTensor(), | ||
T.ConvertImageDtype(torch.float), | ||
] | ||
) | ||
fill = defaultdict(lambda: mean, {datapoints.Mask: 0}) if use_v2 else list(mean) | ||
transforms += [ | ||
T.RandomPhotometricDistort(), | ||
T.RandomZoomOut(fill=fill), | ||
T.RandomIoUCrop(), | ||
T.RandomHorizontalFlip(p=hflip_prob), | ||
] | ||
elif data_augmentation == "ssdlite": | ||
self.transforms = T.Compose( | ||
[ | ||
T.RandomIoUCrop(), | ||
T.RandomHorizontalFlip(p=hflip_prob), | ||
T.PILToTensor(), | ||
T.ConvertImageDtype(torch.float), | ||
] | ||
) | ||
transforms += [ | ||
T.RandomIoUCrop(), | ||
T.RandomHorizontalFlip(p=hflip_prob), | ||
] | ||
else: | ||
raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"') | ||
|
||
if backend == "pil": | ||
# Note: we could just convert to pure tensors even in v2. | ||
transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()] | ||
|
||
transforms += [T.ConvertImageDtype(torch.float)] | ||
|
||
if use_v2: | ||
transforms += [ | ||
T.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.XYXY), | ||
T.SanitizeBoundingBox(), | ||
] | ||
|
||
self.transforms = T.Compose(transforms) | ||
|
||
def __call__(self, img, target): | ||
return self.transforms(img, target) | ||
|
||
|
||
class DetectionPresetEval: | ||
def __init__(self): | ||
self.transforms = T.Compose( | ||
[ | ||
T.PILToTensor(), | ||
T.ConvertImageDtype(torch.float), | ||
] | ||
) | ||
def __init__(self, backend="pil", use_v2=False): | ||
T, _ = get_modules(use_v2) | ||
transforms = [] | ||
backend = backend.lower() | ||
# Conversion may look a bit weird but the assumption of this transform is that the input is always a PIL image | ||
# TODO: Is that still true when using v2, from the dataset??????? | ||
if backend == "pil": | ||
# Note: we could just convert to pure tensors even in v2? | ||
transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()] | ||
elif backend == "tensor": | ||
transforms += [T.PILToTensor()] | ||
elif backend == "datapoint": | ||
transforms += [T.ToImageTensor()] | ||
else: | ||
raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}") | ||
|
||
transforms += [T.ConvertImageDtype(torch.float)] | ||
self.transforms = T.Compose(transforms) | ||
|
||
def __call__(self, img, target): | ||
return self.transforms(img, target) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters