Skip to content

Commit

Permalink
Add --backend support to detection refs
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug committed Jul 11, 2023
1 parent 08c9938 commit 6443e6a
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 79 deletions.
47 changes: 25 additions & 22 deletions references/classification/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ def get_module(use_v2):


class ClassificationPresetTrain:
# Note: this transform assumes that the input to __call__() is always a PIL
# image, regardless of the backend parameter. We may change that in the
# future though, if we change the output type from the dataset.
def __init__(
self,
*,
Expand All @@ -30,42 +33,42 @@ def __init__(
backend="pil",
use_v2=False,
):
module = get_module(use_v2)
T = get_module(use_v2)

transforms = []
backend = backend.lower()
if backend == "tensor":
transforms.append(module.PILToTensor())
transforms.append(T.PILToTensor())
elif backend != "pil":
raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")

transforms.append(module.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True))
transforms.append(T.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True))
if hflip_prob > 0:
transforms.append(module.RandomHorizontalFlip(hflip_prob))
transforms.append(T.RandomHorizontalFlip(hflip_prob))
if auto_augment_policy is not None:
if auto_augment_policy == "ra":
transforms.append(module.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
transforms.append(T.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
elif auto_augment_policy == "ta_wide":
transforms.append(module.TrivialAugmentWide(interpolation=interpolation))
transforms.append(T.TrivialAugmentWide(interpolation=interpolation))
elif auto_augment_policy == "augmix":
transforms.append(module.AugMix(interpolation=interpolation, severity=augmix_severity))
transforms.append(T.AugMix(interpolation=interpolation, severity=augmix_severity))
else:
aa_policy = module.AutoAugmentPolicy(auto_augment_policy)
transforms.append(module.AutoAugment(policy=aa_policy, interpolation=interpolation))
aa_policy = T.AutoAugmentPolicy(auto_augment_policy)
transforms.append(T.AutoAugment(policy=aa_policy, interpolation=interpolation))

if backend == "pil":
transforms.append(module.PILToTensor())
transforms.append(T.PILToTensor())

transforms.extend(
[
module.ConvertImageDtype(torch.float),
module.Normalize(mean=mean, std=std),
T.ConvertImageDtype(torch.float),
T.Normalize(mean=mean, std=std),
]
)
if random_erase_prob > 0:
transforms.append(module.RandomErasing(p=random_erase_prob))
transforms.append(T.RandomErasing(p=random_erase_prob))

self.transforms = module.Compose(transforms)
self.transforms = T.Compose(transforms)

def __call__(self, img):
return self.transforms(img)
Expand All @@ -83,28 +86,28 @@ def __init__(
backend="pil",
use_v2=False,
):
module = get_module(use_v2)
T = get_module(use_v2)
transforms = []
backend = backend.lower()
if backend == "tensor":
transforms.append(module.PILToTensor())
transforms.append(T.PILToTensor())
elif backend != "pil":
raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")

transforms += [
module.Resize(resize_size, interpolation=interpolation, antialias=True),
module.CenterCrop(crop_size),
T.Resize(resize_size, interpolation=interpolation, antialias=True),
T.CenterCrop(crop_size),
]

if backend == "pil":
transforms.append(module.PILToTensor())
transforms.append(T.PILToTensor())

transforms += [
module.ConvertImageDtype(torch.float),
module.Normalize(mean=mean, std=std),
T.ConvertImageDtype(torch.float),
T.Normalize(mean=mean, std=std),
]

self.transforms = module.Compose(transforms)
self.transforms = T.Compose(transforms)

def __call__(self, img):
return self.transforms(img)
142 changes: 89 additions & 53 deletions references/detection/presets.py
Original file line number Diff line number Diff line change
@@ -1,73 +1,109 @@
from collections import defaultdict

import torch
import transforms as T
import transforms as reference_transforms


def get_modules(use_v2):
# Import lazily, inside the branch, so that the V2 warning is not triggered when only V1 is used
if use_v2:
import torchvision.datapoints
import torchvision.transforms.v2

return torchvision.transforms.v2, torchvision.datapoints
else:
return reference_transforms, None


class DetectionPresetTrain:
def __init__(self, *, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0)):
def __init__(
self,
*,
data_augmentation,
hflip_prob=0.5,
mean=(123.0, 117.0, 104.0),
backend="pil",
use_v2=False,
):

T, datapoints = get_modules(use_v2)

transforms = []
backend = backend.lower()
if backend == "datapoint":
transforms.append(T.ToImageTensor())
elif backend == "tensor":
transforms.append(T.PILToTensor())
elif backend != "pil":
raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")

if data_augmentation == "hflip":
self.transforms = T.Compose(
[
T.RandomHorizontalFlip(p=hflip_prob),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
]
)
transforms += [T.RandomHorizontalFlip(p=hflip_prob)]
elif data_augmentation == "lsj":
self.transforms = T.Compose(
[
T.ScaleJitter(target_size=(1024, 1024)),
T.FixedSizeCrop(size=(1024, 1024), fill=mean),
T.RandomHorizontalFlip(p=hflip_prob),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
]
)
transforms += [
T.ScaleJitter(target_size=(1024, 1024), antialias=True),
# TODO: FixedSizeCrop below doesn't work on tensors!
reference_transforms.FixedSizeCrop(size=(1024, 1024), fill=mean),
T.RandomHorizontalFlip(p=hflip_prob),
]
elif data_augmentation == "multiscale":
self.transforms = T.Compose(
[
T.RandomShortestSize(
min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
),
T.RandomHorizontalFlip(p=hflip_prob),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
]
)
transforms += [
T.RandomShortestSize(min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333),
T.RandomHorizontalFlip(p=hflip_prob),
]
elif data_augmentation == "ssd":
self.transforms = T.Compose(
[
T.RandomPhotometricDistort(),
T.RandomZoomOut(fill=list(mean)),
T.RandomIoUCrop(),
T.RandomHorizontalFlip(p=hflip_prob),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
]
)
fill = defaultdict(lambda: mean, {datapoints.Mask: 0}) if use_v2 else list(mean)
transforms += [
T.RandomPhotometricDistort(),
T.RandomZoomOut(fill=fill),
T.RandomIoUCrop(),
T.RandomHorizontalFlip(p=hflip_prob),
]
elif data_augmentation == "ssdlite":
self.transforms = T.Compose(
[
T.RandomIoUCrop(),
T.RandomHorizontalFlip(p=hflip_prob),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
]
)
transforms += [
T.RandomIoUCrop(),
T.RandomHorizontalFlip(p=hflip_prob),
]
else:
raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"')

if backend == "pil":
# Note: we could just convert to pure tensors even in v2.
transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()]

transforms += [T.ConvertImageDtype(torch.float)]

if use_v2:
transforms += [
T.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.XYXY),
T.SanitizeBoundingBox(),
]

self.transforms = T.Compose(transforms)

def __call__(self, img, target):
return self.transforms(img, target)


class DetectionPresetEval:
def __init__(self):
self.transforms = T.Compose(
[
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
]
)
def __init__(self, backend="pil", use_v2=False):
T, _ = get_modules(use_v2)
transforms = []
backend = backend.lower()
# The conversion below may look odd, but this transform assumes that the input is always a PIL image.
# TODO: Confirm whether that assumption still holds when the dataset is used with v2.
if backend == "pil":
# Note: we could just convert to pure tensors even in v2?
transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()]
elif backend == "tensor":
transforms += [T.PILToTensor()]
elif backend == "datapoint":
transforms += [T.ToImageTensor()]
else:
raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")

transforms += [T.ConvertImageDtype(torch.float)]
self.transforms = T.Compose(transforms)

def __call__(self, img, target):
return self.transforms(img, target)
12 changes: 10 additions & 2 deletions references/detection/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,15 @@ def get_dataset(name, image_set, transform, data_path):

def get_transform(train, args):
if train:
return presets.DetectionPresetTrain(data_augmentation=args.data_augmentation)
return presets.DetectionPresetTrain(
data_augmentation=args.data_augmentation, backend=args.backend, use_v2=args.use_v2
)
elif args.weights and args.test_only:
weights = torchvision.models.get_weight(args.weights)
trans = weights.transforms()
return lambda img, target: (trans(img), target)
else:
return presets.DetectionPresetEval()
return presets.DetectionPresetEval(backend=args.backend, use_v2=args.use_v2)


def get_args_parser(add_help=True):
Expand Down Expand Up @@ -159,10 +161,16 @@ def get_args_parser(add_help=True):
help="Use CopyPaste data augmentation. Works only with data-augmentation='lsj'.",
)

parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")

return parser


def main(args):
if args.backend.lower() == "datapoint" and not args.use_v2:
raise ValueError("Use --use-v2 if you want to use the datapoint backend.")

if args.output_dir:
utils.mkdir(args.output_dir)

Expand Down
9 changes: 7 additions & 2 deletions references/detection/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,11 +293,13 @@ def __init__(
target_size: Tuple[int, int],
scale_range: Tuple[float, float] = (0.1, 2.0),
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias=True,
):
super().__init__()
self.target_size = target_size
self.scale_range = scale_range
self.interpolation = interpolation
self.antialias = antialias

def forward(
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
Expand All @@ -315,14 +317,17 @@ def forward(
new_width = int(orig_width * r)
new_height = int(orig_height * r)

image = F.resize(image, [new_height, new_width], interpolation=self.interpolation)
image = F.resize(image, [new_height, new_width], interpolation=self.interpolation, antialias=self.antialias)

if target is not None:
target["boxes"][:, 0::2] *= new_width / orig_width
target["boxes"][:, 1::2] *= new_height / orig_height
if "masks" in target:
target["masks"] = F.resize(
target["masks"], [new_height, new_width], interpolation=InterpolationMode.NEAREST
target["masks"],
[new_height, new_width],
interpolation=InterpolationMode.NEAREST,
antialias=self.antialias,
)

return image, target
Expand Down

0 comments on commit 6443e6a

Please sign in to comment.