diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs
index b9754e29b1c..5e88f5b9bb7 100644
--- a/.git-blame-ignore-revs
+++ b/.git-blame-ignore-revs
@@ -9,3 +9,5 @@ d367a01a18a3ae6bee13d8be3b63fd6a581ea46f
 6ca9c76adb6daf2695d603ad623a9cf1c4f4806f
 # Fix unnecessary exploded black formatting (#7709)
 a335d916db0694770e8152f41e19195de3134523
+# Renaming: `BoundingBox` -> `BoundingBoxes` (#7778)
+332bff937c6711666191880fab57fa2f23ae772e
diff --git a/docs/source/datapoints.rst b/docs/source/datapoints.rst
index 1cc62413e66..55d3cda4a8c 100644
--- a/docs/source/datapoints.rst
+++ b/docs/source/datapoints.rst
@@ -15,5 +15,5 @@ see e.g. :ref:`sphx_glr_auto_examples_plot_transforms_v2_e2e.py`.
     Image
     Video
     BoundingBoxFormat
-    BoundingBox
+    BoundingBoxes
     Mask
diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst
index 0d6961bbe79..73adb3cf3b5 100644
--- a/docs/source/transforms.rst
+++ b/docs/source/transforms.rst
@@ -206,8 +206,8 @@ Miscellaneous
     v2.RandomErasing
     Lambda
     v2.Lambda
-    v2.SanitizeBoundingBox
-    v2.ClampBoundingBox
+    v2.SanitizeBoundingBoxes
+    v2.ClampBoundingBoxes
     v2.UniformTemporalSubsample

 .. _conversion_transforms:
@@ -234,7 +234,6 @@ Conversion
     v2.PILToTensor
     v2.ToImageTensor
     ConvertImageDtype
-    v2.ConvertDtype
     v2.ConvertImageDtype
     v2.ToDtype
     v2.ConvertBoundingBoxFormat
@@ -262,6 +261,22 @@ The new transform can be used standalone or mixed-and-matched with existing tran
     AugMix
     v2.AugMix

+CutMix - MixUp
+--------------
+
+CutMix and MixUp are special transforms that
+are meant to be used on batches rather than on individual images, because they
+combine pairs of images together. These can be used after the dataloader
+(once the samples are batched), or as part of a collation function. See
+:ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage examples.
+
+.. autosummary::
+    :toctree: generated/
+    :template: class.rst
+
+    v2.CutMix
+    v2.MixUp
+
 .. _functional_transforms:

 Functional Transforms
diff --git a/gallery/plot_cutmix_mixup.py b/gallery/plot_cutmix_mixup.py
new file mode 100644
index 00000000000..d1c92a27812
--- /dev/null
+++ b/gallery/plot_cutmix_mixup.py
@@ -0,0 +1,152 @@
+
+"""
+===========================
+How to use CutMix and MixUp
+===========================
+
+:class:`~torchvision.transforms.v2.CutMix` and
+:class:`~torchvision.transforms.v2.MixUp` are popular augmentation strategies
+that can improve classification accuracy.
+
+These transforms are slightly different from the rest of the Torchvision
+transforms, because they expect
+**batches** of samples as input, not individual images. In this example we'll
+explain how to use them: after the ``DataLoader``, or as part of a collation
+function.
+""" + +# %% +import torch +import torchvision +from torchvision.datasets import FakeData + +# We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that +# some APIs may slightly change in the future +torchvision.disable_beta_transforms_warning() + +from torchvision.transforms import v2 + + +NUM_CLASSES = 100 + +# %% +# Pre-processing pipeline +# ----------------------- +# +# We'll use a simple but typical image classification pipeline: + +preproc = v2.Compose([ + v2.PILToTensor(), + v2.RandomResizedCrop(size=(224, 224), antialias=True), + v2.RandomHorizontalFlip(p=0.5), + v2.ToDtype(torch.float32, scale=True), # to float32 in [0, 1] + v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), # typically from ImageNet +]) + +dataset = FakeData(size=1000, num_classes=NUM_CLASSES, transform=preproc) + +img, label = dataset[0] +print(f"{type(img) = }, {img.dtype = }, {img.shape = }, {label = }") + +# %% +# +# One important thing to note is that neither CutMix nor MixUp are part of this +# pre-processing pipeline. We'll add them a bit later once we define the +# DataLoader. Just as a refresher, this is what the DataLoader and training loop +# would look like if we weren't using CutMix or MixUp: + +from torch.utils.data import DataLoader + +dataloader = DataLoader(dataset, batch_size=4, shuffle=True) + +for images, labels in dataloader: + print(f"{images.shape = }, {labels.shape = }") + print(labels.dtype) + # + break +# %% + +# %% +# Where to use MixUp and CutMix +# ----------------------------- +# +# After the DataLoader +# ^^^^^^^^^^^^^^^^^^^^ +# +# Now let's add CutMix and MixUp. The simplest way to do this right after the +# DataLoader: the Dataloader has already batched the images and labels for us, +# and this is exactly what these transforms expect as input: + +dataloader = DataLoader(dataset, batch_size=4, shuffle=True) + +cutmix = v2.Cutmix(num_classes=NUM_CLASSES) +mixup = v2.Mixup(num_classes=NUM_CLASSES) +cutmix_or_mixup = v2.RandomChoice([cutmix, mixup]) + +for images, labels in dataloader: + print(f"Before CutMix/MixUp: {images.shape = }, {labels.shape = }") + images, labels = cutmix_or_mixup(images, labels) + print(f"After CutMix/MixUp: {images.shape = }, {labels.shape = }") + + # + break +# %% +# +# Note how the labels were also transformed: we went from a batched label of +# shape (batch_size,) to a tensor of shape (batch_size, num_classes). The +# transformed labels can still be passed as-is to a loss function like +# :func:`torch.nn.functional.cross_entropy`. +# +# As part of the collation function +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# +# Passing the transforms after the DataLoader is the simplest way to use CutMix +# and MixUp, but one disadvantage is that it does not take advantage of the +# DataLoader multi-processing. For that, we can pass those transforms as part of +# the collation function (refer to the `PyTorch docs +# `_ to learn +# more about collation). + +from torch.utils.data import default_collate + + +def collate_fn(batch): + return cutmix_or_mixup(*default_collate(batch)) + + +dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2, collate_fn=collate_fn) + +for images, labels in dataloader: + print(f"{images.shape = }, {labels.shape = }") + # No need to call cutmix_or_mixup, it's already been called as part of the DataLoader! 
+    #
+    break
+
+# %%
+# Non-standard input format
+# -------------------------
+#
+# So far we've used a typical sample structure where we pass ``(images,
+# labels)`` as inputs. MixUp and CutMix will magically work by default with most
+# common sample structures: tuples where the second parameter is a tensor label,
+# or dicts with a "label[s]" key. Look at the documentation of the
+# ``labels_getter`` parameter for more details.
+#
+# If your samples have a different structure, you can still use CutMix and MixUp
+# by passing a callable to the ``labels_getter`` parameter. For example:
+
+batch = {
+    "imgs": torch.rand(4, 3, 224, 224),
+    "target": {
+        "classes": torch.randint(0, NUM_CLASSES, size=(4,)),
+        "some_other_key": "this is going to be passed-through"
+    }
+}
+
+
+def labels_getter(batch):
+    return batch["target"]["classes"]
+
+
+out = v2.CutMix(num_classes=NUM_CLASSES, labels_getter=labels_getter)(batch)
+print(f"{out['imgs'].shape = }, {out['target']['classes'].shape = }")
diff --git a/gallery/plot_datapoints.py b/gallery/plot_datapoints.py
index 5094de13a3e..fef282ae091 100644
--- a/gallery/plot_datapoints.py
+++ b/gallery/plot_datapoints.py
@@ -47,7 +47,7 @@
 #
 # * :class:`~torchvision.datapoints.Image`
 # * :class:`~torchvision.datapoints.Video`
-# * :class:`~torchvision.datapoints.BoundingBox`
+# * :class:`~torchvision.datapoints.BoundingBoxes`
 # * :class:`~torchvision.datapoints.Mask`
 #
 # How do I construct a datapoint?
@@ -76,11 +76,11 @@
 ########################################################################################################################
 # In general, the datapoints can also store additional metadata that complements the underlying tensor. For example,
-# :class:`~torchvision.datapoints.BoundingBox` stores the coordinate format as well as the spatial size of the
+# :class:`~torchvision.datapoints.BoundingBoxes` stores the coordinate format as well as the spatial size of the
 # corresponding image alongside the actual values:

-bounding_box = datapoints.BoundingBox(
-    [17, 16, 344, 495], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=image.shape[-2:]
+bounding_box = datapoints.BoundingBoxes(
+    [17, 16, 344, 495], format=datapoints.BoundingBoxFormat.XYXY, canvas_size=image.shape[-2:]
 )
 print(bounding_box)
@@ -105,10 +105,10 @@ class PennFudanDataset(torch.utils.data.Dataset):
     def __getitem__(self, item):
         ...
-        target["boxes"] = datapoints.BoundingBox(
+        target["boxes"] = datapoints.BoundingBoxes(
             boxes,
             format=datapoints.BoundingBoxFormat.XYXY,
-            spatial_size=F.get_spatial_size(img),
+            canvas_size=F.get_size(img),
         )
         target["labels"] = labels
         target["masks"] = datapoints.Mask(masks)
@@ -126,10 +126,10 @@ def __getitem__(self, item):
 class WrapPennFudanDataset:
     def __call__(self, img, target):
-        target["boxes"] = datapoints.BoundingBox(
+        target["boxes"] = datapoints.BoundingBoxes(
             target["boxes"],
             format=datapoints.BoundingBoxFormat.XYXY,
-            spatial_size=F.get_spatial_size(img),
+            canvas_size=F.get_size(img),
         )
         target["masks"] = datapoints.Mask(target["masks"])
         return img, target
@@ -147,7 +147,7 @@ def get_transform(train):
 ########################################################################################################################
 # .. note::
 #
-# If both :class:`~torchvision.datapoints.BoundingBox`'es and :class:`~torchvision.datapoints.Mask`'s are included in
+# If both :class:`~torchvision.datapoints.BoundingBoxes` and :class:`~torchvision.datapoints.Mask`'s are included in
 # the sample, ``torchvision.transforms.v2`` will transform them both. Meaning, if you don't need both, dropping or
 # at least not wrapping the obsolete parts, can lead to a significant performance boost.
 #
diff --git a/gallery/plot_transforms_v2.py b/gallery/plot_transforms_v2.py
index d1096bec1e7..88916ba44f9 100644
--- a/gallery/plot_transforms_v2.py
+++ b/gallery/plot_transforms_v2.py
@@ -29,8 +29,8 @@ def load_data():
     masks = datapoints.Mask(merged_masks == labels.view(-1, 1, 1))
-    bounding_boxes = datapoints.BoundingBox(
-        masks_to_boxes(masks), format=datapoints.BoundingBoxFormat.XYXY, spatial_size=image.shape[-2:]
+    bounding_boxes = datapoints.BoundingBoxes(
+        masks_to_boxes(masks), format=datapoints.BoundingBoxFormat.XYXY, canvas_size=image.shape[-2:]
     )
     return path, image, bounding_boxes, masks, labels
diff --git a/gallery/plot_transforms_v2_e2e.py b/gallery/plot_transforms_v2_e2e.py
index 5d8d22dce83..981b1e58832 100644
--- a/gallery/plot_transforms_v2_e2e.py
+++ b/gallery/plot_transforms_v2_e2e.py
@@ -10,7 +10,6 @@
 """
 import pathlib
-from collections import defaultdict
 import PIL.Image
@@ -29,7 +28,7 @@ def show(sample):
     image, target = sample
     if isinstance(image, PIL.Image.Image):
         image = F.to_image_tensor(image)
-    image = F.convert_dtype(image, torch.uint8)
+    image = F.to_dtype(image, torch.uint8, scale=True)
     annotated_image = draw_bounding_boxes(image, target["boxes"], colors="yellow", width=3)
     fig, ax = plt.subplots()
@@ -99,20 +98,18 @@ def load_example_coco_detection_dataset(**kwargs):
 transform = transforms.Compose(
     [
         transforms.RandomPhotometricDistort(),
-        transforms.RandomZoomOut(
-            fill=defaultdict(lambda: 0, {PIL.Image.Image: (123, 117, 104)})
-        ),
+        transforms.RandomZoomOut(fill={PIL.Image.Image: (123, 117, 104), "others": 0}),
         transforms.RandomIoUCrop(),
         transforms.RandomHorizontalFlip(),
         transforms.ToImageTensor(),
         transforms.ConvertImageDtype(torch.float32),
-        transforms.SanitizeBoundingBox(),
+        transforms.SanitizeBoundingBoxes(),
     ]
 )
 ########################################################################################################################
 # .. note::
-#    Although the :class:`~torchvision.transforms.v2.SanitizeBoundingBox` transform is a no-op in this example, but it
+#    Although the :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes` transform is a no-op in this example, it
 #    should be placed at least once at the end of a detection pipeline to remove degenerate bounding boxes as well as
 #    the corresponding labels and optionally masks. It is particularly critical to add it if
 #    :class:`~torchvision.transforms.v2.RandomIoUCrop` was used.
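The note above is easier to see on a concrete sample. The following standalone sketch is illustration only and is not part of the patch: the image, box coordinates, and labels are made up, and it assumes the renamed v2 API introduced in this diff (``datapoints.BoundingBoxes`` with ``canvas_size`` and ``v2.SanitizeBoundingBoxes``). It shows a degenerate box and its matching label being dropped:

import torch
import torchvision
from torchvision import datapoints

# Acknowledge the beta status of the v2 API, as in the gallery examples above.
torchvision.disable_beta_transforms_warning()
from torchvision.transforms import v2

image = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
boxes = datapoints.BoundingBoxes(
    [[10, 10, 80, 80], [50, 50, 50, 90], [100, 120, 180, 200]],  # the second box has zero width
    format=datapoints.BoundingBoxFormat.XYXY,
    canvas_size=(224, 224),
)
target = {"boxes": boxes, "labels": torch.tensor([1, 2, 3])}

# With the default labels_getter, the transform finds target["labels"] and drops
# the entries that correspond to the removed boxes.
image, target = v2.SanitizeBoundingBoxes()(image, target)
print(target["boxes"].shape, target["labels"])  # should leave 2 boxes and 2 labels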
diff --git a/references/classification/presets.py b/references/classification/presets.py index 0f2c914be7e..9b53f0ccd5d 100644 --- a/references/classification/presets.py +++ b/references/classification/presets.py @@ -1,9 +1,23 @@ import torch -from torchvision.transforms import autoaugment, transforms from torchvision.transforms.functional import InterpolationMode +def get_module(use_v2): + # We need a protected import to avoid the V2 warning in case just V1 is used + if use_v2: + import torchvision.transforms.v2 + + return torchvision.transforms.v2 + else: + import torchvision.transforms + + return torchvision.transforms + + class ClassificationPresetTrain: + # Note: this transform assumes that the input to forward() are always PIL + # images, regardless of the backend parameter. We may change that in the + # future though, if we change the output type from the dataset. def __init__( self, *, @@ -17,41 +31,44 @@ def __init__( augmix_severity=3, random_erase_prob=0.0, backend="pil", + use_v2=False, ): - trans = [] + T = get_module(use_v2) + + transforms = [] backend = backend.lower() if backend == "tensor": - trans.append(transforms.PILToTensor()) + transforms.append(T.PILToTensor()) elif backend != "pil": raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}") - trans.append(transforms.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True)) + transforms.append(T.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True)) if hflip_prob > 0: - trans.append(transforms.RandomHorizontalFlip(hflip_prob)) + transforms.append(T.RandomHorizontalFlip(hflip_prob)) if auto_augment_policy is not None: if auto_augment_policy == "ra": - trans.append(autoaugment.RandAugment(interpolation=interpolation, magnitude=ra_magnitude)) + transforms.append(T.RandAugment(interpolation=interpolation, magnitude=ra_magnitude)) elif auto_augment_policy == "ta_wide": - trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation)) + transforms.append(T.TrivialAugmentWide(interpolation=interpolation)) elif auto_augment_policy == "augmix": - trans.append(autoaugment.AugMix(interpolation=interpolation, severity=augmix_severity)) + transforms.append(T.AugMix(interpolation=interpolation, severity=augmix_severity)) else: - aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy) - trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation)) + aa_policy = T.AutoAugmentPolicy(auto_augment_policy) + transforms.append(T.AutoAugment(policy=aa_policy, interpolation=interpolation)) if backend == "pil": - trans.append(transforms.PILToTensor()) + transforms.append(T.PILToTensor()) - trans.extend( + transforms.extend( [ - transforms.ConvertImageDtype(torch.float), - transforms.Normalize(mean=mean, std=std), + T.ConvertImageDtype(torch.float), + T.Normalize(mean=mean, std=std), ] ) if random_erase_prob > 0: - trans.append(transforms.RandomErasing(p=random_erase_prob)) + transforms.append(T.RandomErasing(p=random_erase_prob)) - self.transforms = transforms.Compose(trans) + self.transforms = T.Compose(transforms) def __call__(self, img): return self.transforms(img) @@ -67,28 +84,30 @@ def __init__( std=(0.229, 0.224, 0.225), interpolation=InterpolationMode.BILINEAR, backend="pil", + use_v2=False, ): - trans = [] + T = get_module(use_v2) + transforms = [] backend = backend.lower() if backend == "tensor": - trans.append(transforms.PILToTensor()) + transforms.append(T.PILToTensor()) elif backend != "pil": raise ValueError(f"backend can be 'tensor' or 
'pil', but got {backend}") - trans += [ - transforms.Resize(resize_size, interpolation=interpolation, antialias=True), - transforms.CenterCrop(crop_size), + transforms += [ + T.Resize(resize_size, interpolation=interpolation, antialias=True), + T.CenterCrop(crop_size), ] if backend == "pil": - trans.append(transforms.PILToTensor()) + transforms.append(T.PILToTensor()) - trans += [ - transforms.ConvertImageDtype(torch.float), - transforms.Normalize(mean=mean, std=std), + transforms += [ + T.ConvertImageDtype(torch.float), + T.Normalize(mean=mean, std=std), ] - self.transforms = transforms.Compose(trans) + self.transforms = T.Compose(transforms) def __call__(self, img): return self.transforms(img) diff --git a/references/classification/train.py b/references/classification/train.py index 0c1a301453d..1bb0d86e9a5 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -8,12 +8,12 @@ import torch.utils.data import torchvision import torchvision.transforms -import transforms import utils from sampler import RASampler from torch import nn from torch.utils.data.dataloader import default_collate from torchvision.transforms.functional import InterpolationMode +from transforms import get_mixup_cutmix def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None): @@ -145,6 +145,7 @@ def load_data(traindir, valdir, args): ra_magnitude=ra_magnitude, augmix_severity=augmix_severity, backend=args.backend, + use_v2=args.use_v2, ), ) if args.cache_dataset: @@ -172,6 +173,7 @@ def load_data(traindir, valdir, args): resize_size=val_resize_size, interpolation=interpolation, backend=args.backend, + use_v2=args.use_v2, ) dataset_test = torchvision.datasets.ImageFolder( @@ -216,18 +218,17 @@ def main(args): val_dir = os.path.join(args.data_path, "val") dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args) - collate_fn = None num_classes = len(dataset.classes) - mixup_transforms = [] - if args.mixup_alpha > 0.0: - mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha)) - if args.cutmix_alpha > 0.0: - mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha)) - if mixup_transforms: - mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms) + mixup_cutmix = get_mixup_cutmix( + mixup_alpha=args.mixup_alpha, cutmix_alpha=args.cutmix_alpha, num_categories=num_classes, use_v2=args.use_v2 + ) + if mixup_cutmix is not None: def collate_fn(batch): - return mixupcutmix(*default_collate(batch)) + return mixup_cutmix(*default_collate(batch)) + + else: + collate_fn = default_collate data_loader = torch.utils.data.DataLoader( dataset, @@ -516,6 +517,7 @@ def get_args_parser(add_help=True): ) parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive") + parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms") return parser diff --git a/references/classification/transforms.py b/references/classification/transforms.py index 9a8ef7877d6..3d10388c36f 100644 --- a/references/classification/transforms.py +++ b/references/classification/transforms.py @@ -2,12 +2,35 @@ from typing import Tuple import torch +from presets import get_module from torch import Tensor from torchvision.transforms import functional as F -class RandomMixup(torch.nn.Module): - """Randomly apply Mixup to 
the provided batch and targets.
+def get_mixup_cutmix(*, mixup_alpha, cutmix_alpha, num_categories, use_v2):
+    transforms_module = get_module(use_v2)
+
+    mixup_cutmix = []
+    if mixup_alpha > 0:
+        mixup_cutmix.append(
+            transforms_module.MixUp(alpha=mixup_alpha, num_classes=num_categories)
+            if use_v2
+            else RandomMixUp(num_classes=num_categories, p=1.0, alpha=mixup_alpha)
+        )
+    if cutmix_alpha > 0:
+        mixup_cutmix.append(
+            transforms_module.CutMix(alpha=cutmix_alpha, num_classes=num_categories)
+            if use_v2
+            else RandomCutMix(num_classes=num_categories, p=1.0, alpha=cutmix_alpha)
+        )
+    if not mixup_cutmix:
+        return None
+
+    return transforms_module.RandomChoice(mixup_cutmix)
+
+
+class RandomMixUp(torch.nn.Module):
+    """Randomly apply MixUp to the provided batch and targets.
     The class implements the data augmentations as described in the paper
     `"mixup: Beyond Empirical Risk Minimization" `_.
@@ -89,8 +112,8 @@ def __repr__(self) -> str:
         return s
-class RandomCutmix(torch.nn.Module):
-    """Randomly apply Cutmix to the provided batch and targets.
+class RandomCutMix(torch.nn.Module):
+    """Randomly apply CutMix to the provided batch and targets.
     The class implements the data augmentations as described in the paper
     `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features" `_.
diff --git a/references/detection/coco_utils.py b/references/detection/coco_utils.py
index 38c8279c35e..5269b45abc1 100644
--- a/references/detection/coco_utils.py
+++ b/references/detection/coco_utils.py
@@ -1,4 +1,3 @@
-import copy
 import os
 import torch
@@ -9,24 +8,6 @@
 from pycocotools.coco import COCO
-class FilterAndRemapCocoCategories:
-    def __init__(self, categories, remap=True):
-        self.categories = categories
-        self.remap = remap
-
-    def __call__(self, image, target):
-        anno = target["annotations"]
-        anno = [obj for obj in anno if obj["category_id"] in self.categories]
-        if not self.remap:
-            target["annotations"] = anno
-            return image, target
-        anno = copy.deepcopy(anno)
-        for obj in anno:
-            obj["category_id"] = self.categories.index(obj["category_id"])
-        target["annotations"] = anno
-        return image, target
-
-
 def convert_coco_poly_to_mask(segmentations, height, width):
     masks = []
     for polygons in segmentations:
@@ -49,7 +30,6 @@ def __call__(self, image, target):
         w, h = image.size
         image_id = target["image_id"]
-        image_id = torch.tensor([image_id])
         anno = target["annotations"]
@@ -126,10 +106,6 @@ def _has_valid_annotation(anno):
                 return True
         return False
-    if not isinstance(dataset, torchvision.datasets.CocoDetection):
-        raise TypeError(
-            f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
-        )
     ids = []
     for ds_idx, img_id in enumerate(dataset.ids):
         ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
@@ -196,6 +172,7 @@ def convert_to_coco_api(ds):
 def get_coco_api_from_dataset(dataset):
+    # FIXME: This is... awful?
for _ in range(10): if isinstance(dataset, torchvision.datasets.CocoDetection): break @@ -220,7 +197,7 @@ def __getitem__(self, idx): return img, target -def get_coco(root, image_set, transforms, mode="instances"): +def get_coco(root, image_set, transforms, mode="instances", use_v2=False, with_masks=False): anno_file_template = "{}_{}2017.json" PATHS = { "train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))), @@ -228,17 +205,26 @@ def get_coco(root, image_set, transforms, mode="instances"): # "train": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val"))) } - t = [ConvertCocoPolysToMask()] - - if transforms is not None: - t.append(transforms) - transforms = T.Compose(t) - img_folder, ann_file = PATHS[image_set] img_folder = os.path.join(root, img_folder) ann_file = os.path.join(root, ann_file) - dataset = CocoDetection(img_folder, ann_file, transforms=transforms) + if use_v2: + from torchvision.datasets import wrap_dataset_for_transforms_v2 + + dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms) + target_keys = ["boxes", "labels", "image_id"] + if with_masks: + target_keys += ["masks"] + dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys) + else: + # TODO: handle with_masks for V1? + t = [ConvertCocoPolysToMask()] + if transforms is not None: + t.append(transforms) + transforms = T.Compose(t) + + dataset = CocoDetection(img_folder, ann_file, transforms=transforms) if image_set == "train": dataset = _coco_remove_images_without_annotations(dataset) @@ -246,7 +232,3 @@ def get_coco(root, image_set, transforms, mode="instances"): # dataset = torch.utils.data.Subset(dataset, [i for i in range(500)]) return dataset - - -def get_coco_kp(root, image_set, transforms): - return get_coco(root, image_set, transforms, mode="person_keypoints") diff --git a/references/detection/engine.py b/references/detection/engine.py index 0e5d55f189d..0e9bfffdf8a 100644 --- a/references/detection/engine.py +++ b/references/detection/engine.py @@ -26,7 +26,7 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, sc for images, targets in metric_logger.log_every(data_loader, print_freq, header): images = list(image.to(device) for image in images) - targets = [{k: v.to(device) for k, v in t.items()} for t in targets] + targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets] with torch.cuda.amp.autocast(enabled=scaler is not None): loss_dict = model(images, targets) losses = sum(loss for loss in loss_dict.values()) @@ -97,7 +97,7 @@ def evaluate(model, data_loader, device): outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs] model_time = time.time() - model_time - res = {target["image_id"].item(): output for target, output in zip(targets, outputs)} + res = {target["image_id"]: output for target, output in zip(targets, outputs)} evaluator_time = time.time() coco_evaluator.update(res) evaluator_time = time.time() - evaluator_time diff --git a/references/detection/presets.py b/references/detection/presets.py index 779f3f218ca..098ec85e690 100644 --- a/references/detection/presets.py +++ b/references/detection/presets.py @@ -1,73 +1,109 @@ +from collections import defaultdict + import torch -import transforms as T +import transforms as reference_transforms + + +def get_modules(use_v2): + # We need a protected import to avoid the V2 warning in case just V1 is used + if use_v2: + import 
torchvision.datapoints + import torchvision.transforms.v2 + + return torchvision.transforms.v2, torchvision.datapoints + else: + return reference_transforms, None class DetectionPresetTrain: - def __init__(self, *, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0)): + # Note: this transform assumes that the input to forward() are always PIL + # images, regardless of the backend parameter. + def __init__( + self, + *, + data_augmentation, + hflip_prob=0.5, + mean=(123.0, 117.0, 104.0), + backend="pil", + use_v2=False, + ): + + T, datapoints = get_modules(use_v2) + + transforms = [] + backend = backend.lower() + if backend == "datapoint": + transforms.append(T.ToImageTensor()) + elif backend == "tensor": + transforms.append(T.PILToTensor()) + elif backend != "pil": + raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}") + if data_augmentation == "hflip": - self.transforms = T.Compose( - [ - T.RandomHorizontalFlip(p=hflip_prob), - T.PILToTensor(), - T.ConvertImageDtype(torch.float), - ] - ) + transforms += [T.RandomHorizontalFlip(p=hflip_prob)] elif data_augmentation == "lsj": - self.transforms = T.Compose( - [ - T.ScaleJitter(target_size=(1024, 1024)), - T.FixedSizeCrop(size=(1024, 1024), fill=mean), - T.RandomHorizontalFlip(p=hflip_prob), - T.PILToTensor(), - T.ConvertImageDtype(torch.float), - ] - ) + transforms += [ + T.ScaleJitter(target_size=(1024, 1024), antialias=True), + # TODO: FixedSizeCrop below doesn't work on tensors! + reference_transforms.FixedSizeCrop(size=(1024, 1024), fill=mean), + T.RandomHorizontalFlip(p=hflip_prob), + ] elif data_augmentation == "multiscale": - self.transforms = T.Compose( - [ - T.RandomShortestSize( - min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333 - ), - T.RandomHorizontalFlip(p=hflip_prob), - T.PILToTensor(), - T.ConvertImageDtype(torch.float), - ] - ) + transforms += [ + T.RandomShortestSize(min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333), + T.RandomHorizontalFlip(p=hflip_prob), + ] elif data_augmentation == "ssd": - self.transforms = T.Compose( - [ - T.RandomPhotometricDistort(), - T.RandomZoomOut(fill=list(mean)), - T.RandomIoUCrop(), - T.RandomHorizontalFlip(p=hflip_prob), - T.PILToTensor(), - T.ConvertImageDtype(torch.float), - ] - ) + fill = defaultdict(lambda: mean, {datapoints.Mask: 0}) if use_v2 else list(mean) + transforms += [ + T.RandomPhotometricDistort(), + T.RandomZoomOut(fill=fill), + T.RandomIoUCrop(), + T.RandomHorizontalFlip(p=hflip_prob), + ] elif data_augmentation == "ssdlite": - self.transforms = T.Compose( - [ - T.RandomIoUCrop(), - T.RandomHorizontalFlip(p=hflip_prob), - T.PILToTensor(), - T.ConvertImageDtype(torch.float), - ] - ) + transforms += [ + T.RandomIoUCrop(), + T.RandomHorizontalFlip(p=hflip_prob), + ] else: raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"') + if backend == "pil": + # Note: we could just convert to pure tensors even in v2. 
+ transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()] + + transforms += [T.ConvertImageDtype(torch.float)] + + if use_v2: + transforms += [ + T.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.XYXY), + T.SanitizeBoundingBoxes(), + ] + + self.transforms = T.Compose(transforms) + def __call__(self, img, target): return self.transforms(img, target) class DetectionPresetEval: - def __init__(self): - self.transforms = T.Compose( - [ - T.PILToTensor(), - T.ConvertImageDtype(torch.float), - ] - ) + def __init__(self, backend="pil", use_v2=False): + T, _ = get_modules(use_v2) + transforms = [] + backend = backend.lower() + if backend == "pil": + # Note: we could just convert to pure tensors even in v2? + transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()] + elif backend == "tensor": + transforms += [T.PILToTensor()] + elif backend == "datapoint": + transforms += [T.ToImageTensor()] + else: + raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}") + + transforms += [T.ConvertImageDtype(torch.float)] + self.transforms = T.Compose(transforms) def __call__(self, img, target): return self.transforms(img, target) diff --git a/references/detection/train.py b/references/detection/train.py index dea483c5f75..892ffbbbc1c 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -28,7 +28,7 @@ import torchvision.models.detection import torchvision.models.detection.mask_rcnn import utils -from coco_utils import get_coco, get_coco_kp +from coco_utils import get_coco from engine import evaluate, train_one_epoch from group_by_aspect_ratio import create_aspect_ratio_groups, GroupedBatchSampler from torchvision.transforms import InterpolationMode @@ -40,23 +40,32 @@ def copypaste_collate_fn(batch): return copypaste(*utils.collate_fn(batch)) -def get_dataset(name, image_set, transform, data_path): - paths = {"coco": (data_path, get_coco, 91), "coco_kp": (data_path, get_coco_kp, 2)} - p, ds_fn, num_classes = paths[name] - - ds = ds_fn(p, image_set=image_set, transforms=transform) +def get_dataset(is_train, args): + image_set = "train" if is_train else "val" + num_classes, mode = {"coco": (91, "instances"), "coco_kp": (2, "person_keypoints")}[args.dataset] + with_masks = "mask" in args.model + ds = get_coco( + root=args.data_path, + image_set=image_set, + transforms=get_transform(is_train, args), + mode=mode, + use_v2=args.use_v2, + with_masks=with_masks, + ) return ds, num_classes -def get_transform(train, args): - if train: - return presets.DetectionPresetTrain(data_augmentation=args.data_augmentation) +def get_transform(is_train, args): + if is_train: + return presets.DetectionPresetTrain( + data_augmentation=args.data_augmentation, backend=args.backend, use_v2=args.use_v2 + ) elif args.weights and args.test_only: weights = torchvision.models.get_weight(args.weights) trans = weights.transforms() return lambda img, target: (trans(img), target) else: - return presets.DetectionPresetEval() + return presets.DetectionPresetEval(backend=args.backend, use_v2=args.use_v2) def get_args_parser(add_help=True): @@ -65,7 +74,12 @@ def get_args_parser(add_help=True): parser = argparse.ArgumentParser(description="PyTorch Detection Training", add_help=add_help) parser.add_argument("--data-path", default="/datasets01/COCO/022719/", type=str, help="dataset path") - parser.add_argument("--dataset", default="coco", type=str, help="dataset name") + parser.add_argument( + "--dataset", + default="coco", + type=str, + help="dataset name. 
Use coco for object detection and instance segmentation and coco_kp for Keypoint detection", + ) parser.add_argument("--model", default="maskrcnn_resnet50_fpn", type=str, help="model name") parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)") parser.add_argument( @@ -159,10 +173,22 @@ def get_args_parser(add_help=True): help="Use CopyPaste data augmentation. Works only with data-augmentation='lsj'.", ) + parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive") + parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms") + return parser def main(args): + if args.backend.lower() == "datapoint" and not args.use_v2: + raise ValueError("Use --use-v2 if you want to use the datapoint backend.") + if args.dataset not in ("coco", "coco_kp"): + raise ValueError(f"Dataset should be coco or coco_kp, got {args.dataset}") + if "keypoint" in args.model and args.dataset != "coco_kp": + raise ValueError("Oops, if you want Keypoint detection, set --dataset coco_kp") + if args.dataset == "coco_kp" and args.use_v2: + raise ValueError("KeyPoint detection doesn't support V2 transforms yet") + if args.output_dir: utils.mkdir(args.output_dir) @@ -177,8 +203,8 @@ def main(args): # Data loading code print("Loading data") - dataset, num_classes = get_dataset(args.dataset, "train", get_transform(True, args), args.data_path) - dataset_test, _ = get_dataset(args.dataset, "val", get_transform(False, args), args.data_path) + dataset, num_classes = get_dataset(is_train=True, args=args) + dataset_test, _ = get_dataset(is_train=False, args=args) print("Creating data loaders") if args.distributed: diff --git a/references/detection/transforms.py b/references/detection/transforms.py index d26bf6eac85..65cf4e83592 100644 --- a/references/detection/transforms.py +++ b/references/detection/transforms.py @@ -293,11 +293,13 @@ def __init__( target_size: Tuple[int, int], scale_range: Tuple[float, float] = (0.1, 2.0), interpolation: InterpolationMode = InterpolationMode.BILINEAR, + antialias=True, ): super().__init__() self.target_size = target_size self.scale_range = scale_range self.interpolation = interpolation + self.antialias = antialias def forward( self, image: Tensor, target: Optional[Dict[str, Tensor]] = None @@ -315,14 +317,17 @@ def forward( new_width = int(orig_width * r) new_height = int(orig_height * r) - image = F.resize(image, [new_height, new_width], interpolation=self.interpolation) + image = F.resize(image, [new_height, new_width], interpolation=self.interpolation, antialias=self.antialias) if target is not None: target["boxes"][:, 0::2] *= new_width / orig_width target["boxes"][:, 1::2] *= new_height / orig_height if "masks" in target: target["masks"] = F.resize( - target["masks"], [new_height, new_width], interpolation=InterpolationMode.NEAREST + target["masks"], + [new_height, new_width], + interpolation=InterpolationMode.NEAREST, + antialias=self.antialias, ) return image, target diff --git a/references/segmentation/coco_utils.py b/references/segmentation/coco_utils.py index e02434012f1..6a15dbefb52 100644 --- a/references/segmentation/coco_utils.py +++ b/references/segmentation/coco_utils.py @@ -68,11 +68,6 @@ def _has_valid_annotation(anno): # if more than 1k pixels occupied in the image return sum(obj["area"] for obj in anno) > 1000 - if not isinstance(dataset, torchvision.datasets.CocoDetection): - raise TypeError( - f"This function expects dataset of type 
torchvision.datasets.CocoDetection, instead got {type(dataset)}" - ) - ids = [] for ds_idx, img_id in enumerate(dataset.ids): ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None) @@ -86,7 +81,7 @@ def _has_valid_annotation(anno): return dataset -def get_coco(root, image_set, transforms): +def get_coco(root, image_set, transforms, use_v2=False): PATHS = { "train": ("train2017", os.path.join("annotations", "instances_train2017.json")), "val": ("val2017", os.path.join("annotations", "instances_val2017.json")), @@ -94,13 +89,24 @@ def get_coco(root, image_set, transforms): } CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72] - transforms = Compose([FilterAndRemapCocoCategories(CAT_LIST, remap=True), ConvertCocoPolysToMask(), transforms]) - img_folder, ann_file = PATHS[image_set] img_folder = os.path.join(root, img_folder) ann_file = os.path.join(root, ann_file) - dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms) + # The 2 "Compose" below achieve the same thing: converting coco detection + # samples into segmentation-compatible samples. They just do it with + # slightly different implementations. We could refactor and unify, but + # keeping them separate helps keeping the v2 version clean + if use_v2: + import v2_extras + from torchvision.datasets import wrap_dataset_for_transforms_v2 + + transforms = Compose([v2_extras.CocoDetectionToVOCSegmentation(), transforms]) + dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms) + dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"masks", "labels"}) + else: + transforms = Compose([FilterAndRemapCocoCategories(CAT_LIST, remap=True), ConvertCocoPolysToMask(), transforms]) + dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms) if image_set == "train": dataset = _coco_remove_images_without_annotations(dataset, CAT_LIST) diff --git a/references/segmentation/presets.py b/references/segmentation/presets.py index ed02ae660e4..e62fd5ae301 100644 --- a/references/segmentation/presets.py +++ b/references/segmentation/presets.py @@ -1,39 +1,104 @@ import torch -import transforms as T + + +def get_modules(use_v2): + # We need a protected import to avoid the V2 warning in case just V1 is used + if use_v2: + import torchvision.datapoints + import torchvision.transforms.v2 + import v2_extras + + return torchvision.transforms.v2, torchvision.datapoints, v2_extras + else: + import transforms + + return transforms, None, None class SegmentationPresetTrain: - def __init__(self, *, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): - min_size = int(0.5 * base_size) - max_size = int(2.0 * base_size) + def __init__( + self, + *, + base_size, + crop_size, + hflip_prob=0.5, + mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225), + backend="pil", + use_v2=False, + ): + T, datapoints, v2_extras = get_modules(use_v2) + + transforms = [] + backend = backend.lower() + if backend == "datapoint": + transforms.append(T.ToImageTensor()) + elif backend == "tensor": + transforms.append(T.PILToTensor()) + elif backend != "pil": + raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}") + + transforms += [T.RandomResize(min_size=int(0.5 * base_size), max_size=int(2.0 * base_size))] - trans = [T.RandomResize(min_size, max_size)] if hflip_prob > 0: - trans.append(T.RandomHorizontalFlip(hflip_prob)) - trans.extend( - [ - T.RandomCrop(crop_size), - 
T.PILToTensor(), - T.ConvertImageDtype(torch.float), - T.Normalize(mean=mean, std=std), + transforms += [T.RandomHorizontalFlip(hflip_prob)] + + if use_v2: + # We need a custom pad transform here, since the padding we want to perform here is fundamentally + # different from the padding in `RandomCrop` if `pad_if_needed=True`. + transforms += [v2_extras.PadIfSmaller(crop_size, fill={datapoints.Mask: 255, "others": 0})] + + transforms += [T.RandomCrop(crop_size)] + + if backend == "pil": + transforms += [T.PILToTensor()] + + if use_v2: + img_type = datapoints.Image if backend == "datapoint" else torch.Tensor + transforms += [ + T.ToDtype(dtype={img_type: torch.float32, datapoints.Mask: torch.int64, "others": None}, scale=True) ] - ) - self.transforms = T.Compose(trans) + else: + # No need to explicitly convert masks as they're magically int64 already + transforms += [T.ConvertImageDtype(torch.float)] + + transforms += [T.Normalize(mean=mean, std=std)] + + self.transforms = T.Compose(transforms) def __call__(self, img, target): return self.transforms(img, target) class SegmentationPresetEval: - def __init__(self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): - self.transforms = T.Compose( - [ - T.RandomResize(base_size, base_size), - T.PILToTensor(), - T.ConvertImageDtype(torch.float), - T.Normalize(mean=mean, std=std), - ] - ) + def __init__( + self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), backend="pil", use_v2=False + ): + T, _, _ = get_modules(use_v2) + + transforms = [] + backend = backend.lower() + if backend == "tensor": + transforms += [T.PILToTensor()] + elif backend == "datapoint": + transforms += [T.ToImageTensor()] + elif backend != "pil": + raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}") + + if use_v2: + transforms += [T.Resize(size=(base_size, base_size))] + else: + transforms += [T.RandomResize(min_size=base_size, max_size=base_size)] + + if backend == "pil": + # Note: we could just convert to pure tensors even in v2? 
+ transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()] + + transforms += [ + T.ConvertImageDtype(torch.float), + T.Normalize(mean=mean, std=std), + ] + self.transforms = T.Compose(transforms) def __call__(self, img, target): return self.transforms(img, target) diff --git a/references/segmentation/train.py b/references/segmentation/train.py index 1aa72a9fe38..7ca4bd1c592 100644 --- a/references/segmentation/train.py +++ b/references/segmentation/train.py @@ -14,24 +14,30 @@ from torchvision.transforms import functional as F, InterpolationMode -def get_dataset(dir_path, name, image_set, transform): +def get_dataset(args, is_train): def sbd(*args, **kwargs): + kwargs.pop("use_v2") return torchvision.datasets.SBDataset(*args, mode="segmentation", **kwargs) + def voc(*args, **kwargs): + kwargs.pop("use_v2") + return torchvision.datasets.VOCSegmentation(*args, **kwargs) + paths = { - "voc": (dir_path, torchvision.datasets.VOCSegmentation, 21), - "voc_aug": (dir_path, sbd, 21), - "coco": (dir_path, get_coco, 21), + "voc": (args.data_path, voc, 21), + "voc_aug": (args.data_path, sbd, 21), + "coco": (args.data_path, get_coco, 21), } - p, ds_fn, num_classes = paths[name] + p, ds_fn, num_classes = paths[args.dataset] - ds = ds_fn(p, image_set=image_set, transforms=transform) + image_set = "train" if is_train else "val" + ds = ds_fn(p, image_set=image_set, transforms=get_transform(is_train, args), use_v2=args.use_v2) return ds, num_classes -def get_transform(train, args): - if train: - return presets.SegmentationPresetTrain(base_size=520, crop_size=480) +def get_transform(is_train, args): + if is_train: + return presets.SegmentationPresetTrain(base_size=520, crop_size=480, backend=args.backend, use_v2=args.use_v2) elif args.weights and args.test_only: weights = torchvision.models.get_weight(args.weights) trans = weights.transforms() @@ -44,7 +50,7 @@ def preprocessing(img, target): return preprocessing else: - return presets.SegmentationPresetEval(base_size=520) + return presets.SegmentationPresetEval(base_size=520, backend=args.backend, use_v2=args.use_v2) def criterion(inputs, target): @@ -120,6 +126,12 @@ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, devi def main(args): + if args.backend.lower() != "pil" and not args.use_v2: + # TODO: Support tensor backend in V1? 
+ raise ValueError("Use --use-v2 if you want to use the datapoint or tensor backend.") + if args.use_v2 and args.dataset != "coco": + raise ValueError("v2 is only support supported for coco dataset for now.") + if args.output_dir: utils.mkdir(args.output_dir) @@ -134,8 +146,8 @@ def main(args): else: torch.backends.cudnn.benchmark = True - dataset, num_classes = get_dataset(args.data_path, args.dataset, "train", get_transform(True, args)) - dataset_test, _ = get_dataset(args.data_path, args.dataset, "val", get_transform(False, args)) + dataset, num_classes = get_dataset(args, is_train=True) + dataset_test, _ = get_dataset(args, is_train=False) if args.distributed: train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) @@ -307,6 +319,8 @@ def get_args_parser(add_help=True): # Mixed precision training parameters parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training") + parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive") + parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms") return parser diff --git a/references/segmentation/transforms.py b/references/segmentation/transforms.py index 518048db2fa..2b3e79b1461 100644 --- a/references/segmentation/transforms.py +++ b/references/segmentation/transforms.py @@ -35,7 +35,7 @@ def __init__(self, min_size, max_size=None): def __call__(self, image, target): size = random.randint(self.min_size, self.max_size) - image = F.resize(image, size) + image = F.resize(image, size, antialias=True) target = F.resize(target, size, interpolation=T.InterpolationMode.NEAREST) return image, target diff --git a/references/segmentation/utils.py b/references/segmentation/utils.py index 4ea24db83ed..cb200f23d76 100644 --- a/references/segmentation/utils.py +++ b/references/segmentation/utils.py @@ -267,9 +267,9 @@ def init_distributed_mode(args): args.rank = int(os.environ["RANK"]) args.world_size = int(os.environ["WORLD_SIZE"]) args.gpu = int(os.environ["LOCAL_RANK"]) - elif "SLURM_PROCID" in os.environ: - args.rank = int(os.environ["SLURM_PROCID"]) - args.gpu = args.rank % torch.cuda.device_count() + # elif "SLURM_PROCID" in os.environ: + # args.rank = int(os.environ["SLURM_PROCID"]) + # args.gpu = args.rank % torch.cuda.device_count() elif hasattr(args, "rank"): pass else: diff --git a/references/segmentation/v2_extras.py b/references/segmentation/v2_extras.py new file mode 100644 index 00000000000..f21799e86c8 --- /dev/null +++ b/references/segmentation/v2_extras.py @@ -0,0 +1,83 @@ +"""This file only exists to be lazy-imported and avoid V2-related import warnings when just using V1.""" +import torch +from torchvision import datapoints +from torchvision.transforms import v2 + + +class PadIfSmaller(v2.Transform): + def __init__(self, size, fill=0): + super().__init__() + self.size = size + self.fill = v2._utils._setup_fill_arg(fill) + + def _get_params(self, sample): + _, height, width = v2.utils.query_chw(sample) + padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)] + needs_padding = any(padding) + return dict(padding=padding, needs_padding=needs_padding) + + def _transform(self, inpt, params): + if not params["needs_padding"]: + return inpt + + fill = v2._utils._get_fill(self.fill, type(inpt)) + fill = v2._utils._convert_fill_arg(fill) + + return v2.functional.pad(inpt, padding=params["padding"], fill=fill) + + +class CocoDetectionToVOCSegmentation(v2.Transform): + """Turn samples from 
datasets.CocoDetection into the same format as VOCSegmentation. + + This is achieved in two steps: + + 1. COCO differentiates between 91 categories while VOC only supports 21, including background for both. Fortunately, + the COCO categories are a superset of the VOC ones and thus can be mapped. Instances of the 70 categories not + present in VOC are dropped and replaced by background. + 2. COCO only offers detection masks, i.e. a (N, H, W) bool-ish tensor, where the truthy values in each individual + mask denote the instance. However, a segmentation mask is a (H, W) integer tensor (typically torch.uint8), where + the value of each pixel denotes the category it belongs to. The detection masks are merged into one segmentation + mask while pixels that belong to multiple detection masks are marked as invalid. + """ + + COCO_TO_VOC_LABEL_MAP = dict( + zip( + [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72], + range(21), + ) + ) + INVALID_VALUE = 255 + + def _coco_detection_masks_to_voc_segmentation_mask(self, target): + if "masks" not in target: + return None + + instance_masks, instance_labels_coco = target["masks"], target["labels"] + + valid_labels_voc = [ + (idx, label_voc) + for idx, label_coco in enumerate(instance_labels_coco.tolist()) + if (label_voc := self.COCO_TO_VOC_LABEL_MAP.get(label_coco)) is not None + ] + + if not valid_labels_voc: + return None + + valid_voc_category_idcs, instance_labels_voc = zip(*valid_labels_voc) + + instance_masks = instance_masks[list(valid_voc_category_idcs)].to(torch.uint8) + instance_labels_voc = torch.tensor(instance_labels_voc, dtype=torch.uint8) + + # Calling `.max()` on the stacked detection masks works fine to separate background from foreground as long as + # there is at most a single instance per pixel. Overlapping instances will be filtered out in the next step. 
+ segmentation_mask, _ = (instance_masks * instance_labels_voc.reshape(-1, 1, 1)).max(dim=0) + segmentation_mask[instance_masks.sum(dim=0) > 1] = self.INVALID_VALUE + + return segmentation_mask + + def forward(self, image, target): + segmentation_mask = self._coco_detection_masks_to_voc_segmentation_mask(target) + if segmentation_mask is None: + segmentation_mask = torch.zeros(v2.functional.get_spatial_size(image), dtype=torch.uint8) + + return image, datapoints.Mask(segmentation_mask) diff --git a/test/common_utils.py b/test/common_utils.py index 32f36cf5a21..3f8a12e161c 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -27,7 +27,7 @@ from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair from torchvision import datapoints, io from torchvision.transforms._functional_tensor import _max_value as get_max_value -from torchvision.transforms.v2.functional import convert_dtype_image_tensor, to_image_tensor +from torchvision.transforms.v2.functional import to_dtype_image_tensor, to_image_pil, to_image_tensor IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"]) @@ -410,6 +410,9 @@ def load(self, device="cpu"): ) +# new v2 default +DEFAULT_SIZE = (17, 11) +# old v2 defaults DEFAULT_SQUARE_SPATIAL_SIZE = 15 DEFAULT_LANDSCAPE_SPATIAL_SIZE = (7, 33) DEFAULT_PORTRAIT_SPATIAL_SIZE = (31, 9) @@ -417,13 +420,12 @@ def load(self, device="cpu"): DEFAULT_LANDSCAPE_SPATIAL_SIZE, DEFAULT_PORTRAIT_SPATIAL_SIZE, DEFAULT_SQUARE_SPATIAL_SIZE, - "random", ) -def _parse_spatial_size(size, *, name="size"): +def _parse_canvas_size(size, *, name="size"): if size == "random": - return tuple(torch.randint(15, 33, (2,)).tolist()) + raise ValueError("This should never happen") elif isinstance(size, int) and size > 0: return (size, size) elif ( @@ -476,12 +478,13 @@ def load(self, device): @dataclasses.dataclass class ImageLoader(TensorLoader): - spatial_size: Tuple[int, int] = dataclasses.field(init=False) + canvas_size: Tuple[int, int] = dataclasses.field(init=False) num_channels: int = dataclasses.field(init=False) memory_format: torch.memory_format = torch.contiguous_format + canvas_size: Tuple[int, int] = dataclasses.field(init=False) def __post_init__(self): - self.spatial_size = self.shape[-2:] + self.canvas_size = self.canvas_size = self.shape[-2:] self.num_channels = self.shape[-3] def load(self, device): @@ -503,8 +506,41 @@ def get_num_channels(color_space): return num_channels +def make_image( + size=DEFAULT_SIZE, + *, + color_space="RGB", + batch_dims=(), + dtype=None, + device="cpu", + memory_format=torch.contiguous_format, +): + dtype = dtype or torch.uint8 + max_value = get_max_value(dtype) + data = torch.testing.make_tensor( + (*batch_dims, get_num_channels(color_space), *size), + low=0, + high=max_value, + dtype=dtype, + device=device, + memory_format=memory_format, + ) + if color_space in {"GRAY_ALPHA", "RGBA"}: + data[..., -1, :, :] = max_value + + return datapoints.Image(data) + + +def make_image_tensor(*args, **kwargs): + return make_image(*args, **kwargs).as_subclass(torch.Tensor) + + +def make_image_pil(*args, **kwargs): + return to_image_pil(make_image(*args, **kwargs)) + + def make_image_loader( - size="random", + size=DEFAULT_PORTRAIT_SPATIAL_SIZE, *, color_space="RGB", extra_dims=(), @@ -512,24 +548,25 @@ def make_image_loader( constant_alpha=True, memory_format=torch.contiguous_format, ): - size = _parse_spatial_size(size) + if not constant_alpha: + raise ValueError("This should never happen") + 
size = _parse_canvas_size(size) num_channels = get_num_channels(color_space) def fn(shape, dtype, device, memory_format): - max_value = get_max_value(dtype) - data = torch.testing.make_tensor( - shape, low=0, high=max_value, dtype=dtype, device=device, memory_format=memory_format + *batch_dims, _, height, width = shape + return make_image( + (height, width), + color_space=color_space, + batch_dims=batch_dims, + dtype=dtype, + device=device, + memory_format=memory_format, ) - if color_space in {"GRAY_ALPHA", "RGBA"} and constant_alpha: - data[..., -1, :, :] = max_value - return datapoints.Image(data) return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype, memory_format=memory_format) -make_image = from_loader(make_image_loader) - - def make_image_loaders( *, sizes=DEFAULT_SPATIAL_SIZES, @@ -551,9 +588,9 @@ def make_image_loaders( def make_image_loader_for_interpolation( - size="random", *, color_space="RGB", dtype=torch.uint8, memory_format=torch.contiguous_format + size=(233, 147), *, color_space="RGB", dtype=torch.uint8, memory_format=torch.contiguous_format ): - size = _parse_spatial_size(size) + size = _parse_canvas_size(size) num_channels = get_num_channels(color_space) def fn(shape, dtype, device, memory_format): @@ -577,7 +614,7 @@ def fn(shape, dtype, device, memory_format): image_tensor = image_tensor.to(device=device, memory_format=memory_format, copy=True) else: image_tensor = image_tensor.to(device=device) - image_tensor = convert_dtype_image_tensor(image_tensor, dtype=dtype) + image_tensor = to_dtype_image_tensor(image_tensor, dtype=dtype, scale=True) return datapoints.Image(image_tensor) @@ -595,85 +632,88 @@ def make_image_loaders_for_interpolation( @dataclasses.dataclass -class BoundingBoxLoader(TensorLoader): +class BoundingBoxesLoader(TensorLoader): format: datapoints.BoundingBoxFormat spatial_size: Tuple[int, int] + canvas_size: Tuple[int, int] = dataclasses.field(init=False) + + def __post_init__(self): + self.canvas_size = self.spatial_size + + +def make_bounding_box( + canvas_size=DEFAULT_SIZE, + *, + format=datapoints.BoundingBoxFormat.XYXY, + batch_dims=(), + dtype=None, + device="cpu", +): + def sample_position(values, max_value): + # We cannot use torch.randint directly here, because it only allows integer scalars as values for low and high. + # However, if we have batch_dims, we need tensors as limits. 
+ return torch.stack([torch.randint(max_value - v, ()) for v in values.flatten().tolist()]).reshape(values.shape) + + if isinstance(format, str): + format = datapoints.BoundingBoxFormat[format] + + dtype = dtype or torch.float32 + + if any(dim == 0 for dim in batch_dims): + return datapoints.BoundingBoxes( + torch.empty(*batch_dims, 4, dtype=dtype, device=device), format=format, canvas_size=canvas_size + ) + h, w = [torch.randint(1, c, batch_dims) for c in canvas_size] + y = sample_position(h, canvas_size[0]) + x = sample_position(w, canvas_size[1]) + + if format is datapoints.BoundingBoxFormat.XYWH: + parts = (x, y, w, h) + elif format is datapoints.BoundingBoxFormat.XYXY: + x1, y1 = x, y + x2 = x1 + w + y2 = y1 + h + parts = (x1, y1, x2, y2) + elif format is datapoints.BoundingBoxFormat.CXCYWH: + cx = x + w / 2 + cy = y + h / 2 + parts = (cx, cy, w, h) + else: + raise ValueError(f"Format {format} is not supported") -def randint_with_tensor_bounds(arg1, arg2=None, **kwargs): - low, high = torch.broadcast_tensors( - *[torch.as_tensor(arg) for arg in ((0, arg1) if arg2 is None else (arg1, arg2))] + return datapoints.BoundingBoxes( + torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, canvas_size=canvas_size ) - return torch.stack( - [ - torch.randint(low_scalar, high_scalar, (), **kwargs) - for low_scalar, high_scalar in zip(low.flatten().tolist(), high.flatten().tolist()) - ] - ).reshape(low.shape) -def make_bounding_box_loader(*, extra_dims=(), format, spatial_size="random", dtype=torch.float32): +def make_bounding_box_loader(*, extra_dims=(), format, canvas_size=DEFAULT_PORTRAIT_SPATIAL_SIZE, dtype=torch.float32): if isinstance(format, str): format = datapoints.BoundingBoxFormat[format] - if format not in { - datapoints.BoundingBoxFormat.XYXY, - datapoints.BoundingBoxFormat.XYWH, - datapoints.BoundingBoxFormat.CXCYWH, - }: - raise pytest.UsageError(f"Can't make bounding box in format {format}") - spatial_size = _parse_spatial_size(spatial_size, name="spatial_size") + canvas_size = _parse_canvas_size(canvas_size, name="canvas_size") def fn(shape, dtype, device): - *extra_dims, num_coordinates = shape + *batch_dims, num_coordinates = shape if num_coordinates != 4: raise pytest.UsageError() - if any(dim == 0 for dim in extra_dims): - return datapoints.BoundingBox( - torch.empty(*extra_dims, 4, dtype=dtype, device=device), format=format, spatial_size=spatial_size - ) - - height, width = spatial_size - - if format == datapoints.BoundingBoxFormat.XYXY: - x1 = torch.randint(0, width // 2, extra_dims) - y1 = torch.randint(0, height // 2, extra_dims) - x2 = randint_with_tensor_bounds(x1 + 1, width - x1) + x1 - y2 = randint_with_tensor_bounds(y1 + 1, height - y1) + y1 - parts = (x1, y1, x2, y2) - elif format == datapoints.BoundingBoxFormat.XYWH: - x = torch.randint(0, width // 2, extra_dims) - y = torch.randint(0, height // 2, extra_dims) - w = randint_with_tensor_bounds(1, width - x) - h = randint_with_tensor_bounds(1, height - y) - parts = (x, y, w, h) - else: # format == features.BoundingBoxFormat.CXCYWH: - cx = torch.randint(1, width - 1, extra_dims) - cy = torch.randint(1, height - 1, extra_dims) - w = randint_with_tensor_bounds(1, torch.minimum(cx, width - cx) + 1) - h = randint_with_tensor_bounds(1, torch.minimum(cy, height - cy) + 1) - parts = (cx, cy, w, h) - - return datapoints.BoundingBox( - torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, spatial_size=spatial_size + return make_bounding_box( + format=format, canvas_size=canvas_size, 
batch_dims=batch_dims, dtype=dtype, device=device ) - return BoundingBoxLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, spatial_size=spatial_size) - - -make_bounding_box = from_loader(make_bounding_box_loader) + return BoundingBoxesLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, spatial_size=canvas_size) def make_bounding_box_loaders( *, extra_dims=DEFAULT_EXTRA_DIMS, formats=tuple(datapoints.BoundingBoxFormat), - spatial_size="random", + canvas_size=DEFAULT_PORTRAIT_SPATIAL_SIZE, dtypes=(torch.float32, torch.float64, torch.int64), ): for params in combinations_grid(extra_dims=extra_dims, format=formats, dtype=dtypes): - yield make_bounding_box_loader(**params, spatial_size=spatial_size) + yield make_bounding_box_loader(**params, canvas_size=canvas_size) make_bounding_boxes = from_loaders(make_bounding_box_loaders) @@ -683,24 +723,35 @@ class MaskLoader(TensorLoader): pass -def make_detection_mask_loader(size="random", *, num_objects="random", extra_dims=(), dtype=torch.uint8): +def make_detection_mask(size=DEFAULT_SIZE, *, num_objects=5, batch_dims=(), dtype=None, device="cpu"): + """Make a "detection" mask, i.e. (*, N, H, W), where each object is encoded as one of N boolean masks""" + return datapoints.Mask( + torch.testing.make_tensor( + (*batch_dims, num_objects, *size), + low=0, + high=2, + dtype=dtype or torch.bool, + device=device, + ) + ) + + +def make_detection_mask_loader(size=DEFAULT_PORTRAIT_SPATIAL_SIZE, *, num_objects=5, extra_dims=(), dtype=torch.uint8): # This produces "detection" masks, i.e. `(*, N, H, W)`, where `N` denotes the number of objects - size = _parse_spatial_size(size) - num_objects = int(torch.randint(1, 11, ())) if num_objects == "random" else num_objects + size = _parse_canvas_size(size) def fn(shape, dtype, device): - data = torch.testing.make_tensor(shape, low=0, high=2, dtype=dtype, device=device) - return datapoints.Mask(data) + *batch_dims, num_objects, height, width = shape + return make_detection_mask( + (height, width), num_objects=num_objects, batch_dims=batch_dims, dtype=dtype, device=device + ) return MaskLoader(fn, shape=(*extra_dims, num_objects, *size), dtype=dtype) -make_detection_mask = from_loader(make_detection_mask_loader) - - def make_detection_mask_loaders( sizes=DEFAULT_SPATIAL_SIZES, - num_objects=(1, 0, "random"), + num_objects=(1, 0, 5), extra_dims=DEFAULT_EXTRA_DIMS, dtypes=(torch.uint8,), ): @@ -711,25 +762,38 @@ def make_detection_mask_loaders( make_detection_masks = from_loaders(make_detection_mask_loaders) -def make_segmentation_mask_loader(size="random", *, num_categories="random", extra_dims=(), dtype=torch.uint8): - # This produces "segmentation" masks, i.e. `(*, H, W)`, where the category is encoded in the values - size = _parse_spatial_size(size) - num_categories = int(torch.randint(1, 11, ())) if num_categories == "random" else num_categories +def make_segmentation_mask(size=DEFAULT_SIZE, *, num_categories=10, batch_dims=(), dtype=None, device="cpu"): + """Make a "segmentation" mask, i.e. 
(*, H, W), where the category is encoded as pixel value""" + return datapoints.Mask( + torch.testing.make_tensor( + (*batch_dims, *size), + low=0, + high=num_categories, + dtype=dtype or torch.uint8, + device=device, + ) + ) - def fn(shape, dtype, device): - data = torch.testing.make_tensor(shape, low=0, high=num_categories, dtype=dtype, device=device) - return datapoints.Mask(data) - return MaskLoader(fn, shape=(*extra_dims, *size), dtype=dtype) +def make_segmentation_mask_loader( + size=DEFAULT_PORTRAIT_SPATIAL_SIZE, *, num_categories=10, extra_dims=(), dtype=torch.uint8 +): + # This produces "segmentation" masks, i.e. `(*, H, W)`, where the category is encoded in the values + canvas_size = _parse_canvas_size(size) + def fn(shape, dtype, device): + *batch_dims, height, width = shape + return make_segmentation_mask( + (height, width), num_categories=num_categories, batch_dims=batch_dims, dtype=dtype, device=device + ) -make_segmentation_mask = from_loader(make_segmentation_mask_loader) + return MaskLoader(fn, shape=(*extra_dims, *canvas_size), dtype=dtype) def make_segmentation_mask_loaders( *, sizes=DEFAULT_SPATIAL_SIZES, - num_categories=(1, 2, "random"), + num_categories=(1, 2, 10), extra_dims=DEFAULT_EXTRA_DIMS, dtypes=(torch.uint8,), ): @@ -743,8 +807,8 @@ def make_segmentation_mask_loaders( def make_mask_loaders( *, sizes=DEFAULT_SPATIAL_SIZES, - num_objects=(1, 0, "random"), - num_categories=(1, 2, "random"), + num_objects=(1, 0, 5), + num_categories=(1, 2, 10), extra_dims=DEFAULT_EXTRA_DIMS, dtypes=(torch.uint8,), ): @@ -761,29 +825,35 @@ class VideoLoader(ImageLoader): pass +def make_video(size=DEFAULT_SIZE, *, num_frames=3, batch_dims=(), **kwargs): + return datapoints.Video(make_image(size, batch_dims=(*batch_dims, num_frames), **kwargs)) + + def make_video_loader( - size="random", + size=DEFAULT_PORTRAIT_SPATIAL_SIZE, *, color_space="RGB", - num_frames="random", + num_frames=3, extra_dims=(), dtype=torch.uint8, ): - size = _parse_spatial_size(size) - num_frames = int(torch.randint(1, 5, ())) if num_frames == "random" else num_frames + size = _parse_canvas_size(size) def fn(shape, dtype, device, memory_format): - video = make_image( - size=shape[-2:], extra_dims=shape[:-3], dtype=dtype, device=device, memory_format=memory_format + *batch_dims, num_frames, _, height, width = shape + return make_video( + (height, width), + num_frames=num_frames, + batch_dims=batch_dims, + color_space=color_space, + dtype=dtype, + device=device, + memory_format=memory_format, ) - return datapoints.Video(video) return VideoLoader(fn, shape=(*extra_dims, num_frames, get_num_channels(color_space), *size), dtype=dtype) -make_video = from_loader(make_video_loader) - - def make_video_loaders( *, sizes=DEFAULT_SPATIAL_SIZES, @@ -791,7 +861,7 @@ def make_video_loaders( "GRAY", "RGB", ), - num_frames=(1, 0, "random"), + num_frames=(1, 0, 3), extra_dims=DEFAULT_EXTRA_DIMS, dtypes=(torch.uint8, torch.float32, torch.float64), ): diff --git a/test/datasets_utils.py b/test/datasets_utils.py index 169437a7424..ab325a8062e 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -571,7 +571,7 @@ def test_transforms_v2_wrapper(self, config): from torchvision.datasets import wrap_dataset_for_transforms_v2 try: - with self.create_dataset(config) as (dataset, _): + with self.create_dataset(config) as (dataset, info): for target_keys in [None, "all"]: if target_keys is not None and self.DATASET_CLASS not in { torchvision.datasets.CocoDetection, @@ -584,8 +584,10 @@ def test_transforms_v2_wrapper(self, 
config): continue wrapped_dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys) - wrapped_sample = wrapped_dataset[0] + assert isinstance(wrapped_dataset, self.DATASET_CLASS) + assert len(wrapped_dataset) == info["num_examples"] + wrapped_sample = wrapped_dataset[0] assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample) except TypeError as error: msg = f"No wrapper exists for dataset class {type(dataset).__name__}" diff --git a/test/test_datapoints.py b/test/test_datapoints.py index 1334fd7283b..f0a44ec1720 100644 --- a/test/test_datapoints.py +++ b/test/test_datapoints.py @@ -27,7 +27,7 @@ def test_mask_instance(data): "format", ["XYXY", "CXCYWH", datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH] ) def test_bbox_instance(data, format): - bboxes = datapoints.BoundingBox(data, format=format, spatial_size=(32, 32)) + bboxes = datapoints.BoundingBoxes(data, format=format, canvas_size=(32, 32)) assert isinstance(bboxes, torch.Tensor) assert bboxes.ndim == 2 and bboxes.shape[1] == 4 if isinstance(format, str): @@ -164,7 +164,7 @@ def test_wrap_like(): [ datapoints.Image(torch.rand(3, 16, 16)), datapoints.Video(torch.rand(2, 3, 16, 16)), - datapoints.BoundingBox([0.0, 1.0, 2.0, 3.0], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(10, 10)), + datapoints.BoundingBoxes([0.0, 1.0, 2.0, 3.0], format=datapoints.BoundingBoxFormat.XYXY, canvas_size=(10, 10)), datapoints.Mask(torch.randint(0, 256, (16, 16), dtype=torch.uint8)), ], ) diff --git a/test/test_extended_models.py b/test/test_extended_models.py index 0866cc0f8a3..96a3fc5f8ed 100644 --- a/test/test_extended_models.py +++ b/test/test_extended_models.py @@ -103,17 +103,18 @@ def test_weights_deserializable(name): assert pickle.loads(pickle.dumps(weights)) is weights +def get_models_from_module(module): + return [ + v.__name__ + for k, v in module.__dict__.items() + if callable(v) and k[0].islower() and k[0] != "_" and k not in models._api.__all__ + ] + + @pytest.mark.parametrize( "module", [models, models.detection, models.quantization, models.segmentation, models.video, models.optical_flow] ) def test_list_models(module): - def get_models_from_module(module): - return [ - v.__name__ - for k, v in module.__dict__.items() - if callable(v) and k[0].islower() and k[0] != "_" and k not in models._api.__all__ - ] - a = set(get_models_from_module(module)) b = set(x.replace("quantized_", "") for x in models.list_models(module)) @@ -121,6 +122,65 @@ def get_models_from_module(module): assert a == b +@pytest.mark.parametrize( + "include_filters", + [ + None, + [], + (), + "", + "*resnet*", + ["*alexnet*"], + "*not-existing-model-for-test?", + ["*resnet*", "*alexnet*"], + ["*resnet*", "*alexnet*", "*not-existing-model-for-test?"], + ("*resnet*", "*alexnet*"), + set(["*resnet*", "*alexnet*"]), + ], +) +@pytest.mark.parametrize( + "exclude_filters", + [ + None, + [], + (), + "", + "*resnet*", + ["*alexnet*"], + ["*not-existing-model-for-test?"], + ["resnet34", "*not-existing-model-for-test?"], + ["resnet34", "*resnet1*"], + ("resnet34", "*resnet1*"), + set(["resnet34", "*resnet1*"]), + ], +) +def test_list_models_filters(include_filters, exclude_filters): + actual = set(models.list_models(models, include=include_filters, exclude=exclude_filters)) + classification_models = set(get_models_from_module(models)) + + if isinstance(include_filters, str): + include_filters = [include_filters] + if isinstance(exclude_filters, str): + exclude_filters = [exclude_filters] + + if 
include_filters: + expected = set() + for include_f in include_filters: + include_f = include_f.strip("*?") + expected = expected | set(x for x in classification_models if include_f in x) + else: + expected = classification_models + + if exclude_filters: + for exclude_f in exclude_filters: + exclude_f = exclude_f.strip("*?") + if exclude_f != "": + a_exclude = set(x for x in classification_models if exclude_f in x) + expected = expected - a_exclude + + assert expected == actual + + @pytest.mark.parametrize( "name, weight", [ diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 255c3b5c32f..7bed48e6c15 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -20,7 +20,7 @@ from prototype_common_utils import make_label, make_one_hot_labels -from torchvision.datapoints import BoundingBox, BoundingBoxFormat, Image, Mask, Video +from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video from torchvision.prototype import datapoints, transforms from torchvision.transforms.v2._utils import _convert_fill_arg from torchvision.transforms.v2.functional import InterpolationMode, pil_to_tensor, to_image_pil @@ -60,8 +60,8 @@ def parametrize(transforms_with_inputs): ], ) for transform in [ - transforms.RandomMixup(alpha=1.0), - transforms.RandomCutmix(alpha=1.0), + transforms.RandomMixUp(alpha=1.0), + transforms.RandomCutMix(alpha=1.0), ] ] ) @@ -101,10 +101,10 @@ def test__extract_image_targets_assertion(self, mocker): self.create_fake_image(mocker, Image), # labels, bboxes, masks mocker.MagicMock(spec=datapoints.Label), - mocker.MagicMock(spec=BoundingBox), + mocker.MagicMock(spec=BoundingBoxes), mocker.MagicMock(spec=Mask), # labels, bboxes, masks - mocker.MagicMock(spec=BoundingBox), + mocker.MagicMock(spec=BoundingBoxes), mocker.MagicMock(spec=Mask), ] @@ -122,11 +122,11 @@ def test__extract_image_targets(self, image_type, label_type, mocker): self.create_fake_image(mocker, image_type), # labels, bboxes, masks mocker.MagicMock(spec=label_type), - mocker.MagicMock(spec=BoundingBox), + mocker.MagicMock(spec=BoundingBoxes), mocker.MagicMock(spec=Mask), # labels, bboxes, masks mocker.MagicMock(spec=label_type), - mocker.MagicMock(spec=BoundingBox), + mocker.MagicMock(spec=BoundingBoxes), mocker.MagicMock(spec=Mask), ] @@ -142,7 +142,7 @@ def test__extract_image_targets(self, image_type, label_type, mocker): for target in targets: for key, type_ in [ - ("boxes", BoundingBox), + ("boxes", BoundingBoxes), ("masks", Mask), ("labels", label_type), ]: @@ -163,8 +163,8 @@ def test__copy_paste(self, label_type): if label_type == datapoints.OneHotLabel: labels = torch.nn.functional.one_hot(labels, num_classes=5) target = { - "boxes": BoundingBox( - torch.tensor([[2.0, 3.0, 8.0, 9.0], [20.0, 20.0, 30.0, 30.0]]), format="XYXY", spatial_size=(32, 32) + "boxes": BoundingBoxes( + torch.tensor([[2.0, 3.0, 8.0, 9.0], [20.0, 20.0, 30.0, 30.0]]), format="XYXY", canvas_size=(32, 32) ), "masks": Mask(masks), "labels": label_type(labels), @@ -178,8 +178,8 @@ def test__copy_paste(self, label_type): if label_type == datapoints.OneHotLabel: paste_labels = torch.nn.functional.one_hot(paste_labels, num_classes=5) paste_target = { - "boxes": BoundingBox( - torch.tensor([[12.0, 13.0, 19.0, 18.0], [1.0, 15.0, 8.0, 19.0]]), format="XYXY", spatial_size=(32, 32) + "boxes": BoundingBoxes( + torch.tensor([[12.0, 13.0, 19.0, 18.0], [1.0, 15.0, 8.0, 19.0]]), format="XYXY", canvas_size=(32, 32) ), "masks": Mask(paste_masks), "labels": 
label_type(paste_labels), @@ -210,13 +210,13 @@ class TestFixedSizeCrop: def test__get_params(self, mocker): crop_size = (7, 7) batch_shape = (10,) - spatial_size = (11, 5) + canvas_size = (11, 5) transform = transforms.FixedSizeCrop(size=crop_size) flat_inputs = [ - make_image(size=spatial_size, color_space="RGB"), - make_bounding_box(format=BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=batch_shape), + make_image(size=canvas_size, color_space="RGB"), + make_bounding_box(format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=batch_shape), ] params = transform._get_params(flat_inputs) @@ -295,7 +295,7 @@ def test__transform(self, mocker, needs): def test__transform_culling(self, mocker): batch_size = 10 - spatial_size = (10, 10) + canvas_size = (10, 10) is_valid = torch.randint(0, 2, (batch_size,), dtype=torch.bool) mocker.patch( @@ -304,17 +304,17 @@ def test__transform_culling(self, mocker): needs_crop=True, top=0, left=0, - height=spatial_size[0], - width=spatial_size[1], + height=canvas_size[0], + width=canvas_size[1], is_valid=is_valid, needs_pad=False, ), ) bounding_boxes = make_bounding_box( - format=BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=(batch_size,) + format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(batch_size,) ) - masks = make_detection_mask(size=spatial_size, extra_dims=(batch_size,)) + masks = make_detection_mask(size=canvas_size, batch_dims=(batch_size,)) labels = make_label(extra_dims=(batch_size,)) transform = transforms.FixedSizeCrop((-1, -1)) @@ -332,9 +332,9 @@ def test__transform_culling(self, mocker): assert_equal(output["masks"], masks[is_valid]) assert_equal(output["labels"], labels[is_valid]) - def test__transform_bounding_box_clamping(self, mocker): + def test__transform_bounding_boxes_clamping(self, mocker): batch_size = 3 - spatial_size = (10, 10) + canvas_size = (10, 10) mocker.patch( "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params", @@ -342,22 +342,22 @@ def test__transform_bounding_box_clamping(self, mocker): needs_crop=True, top=0, left=0, - height=spatial_size[0], - width=spatial_size[1], + height=canvas_size[0], + width=canvas_size[1], is_valid=torch.full((batch_size,), fill_value=True), needs_pad=False, ), ) - bounding_box = make_bounding_box( - format=BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=(batch_size,) + bounding_boxes = make_bounding_box( + format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(batch_size,) ) - mock = mocker.patch("torchvision.prototype.transforms._geometry.F.clamp_bounding_box") + mock = mocker.patch("torchvision.prototype.transforms._geometry.F.clamp_bounding_boxes") transform = transforms.FixedSizeCrop((-1, -1)) mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True) - transform(bounding_box) + transform(bounding_boxes) mock.assert_called_once() @@ -390,7 +390,7 @@ class TestPermuteDimensions: def test_call(self, dims, inverse_dims): sample = dict( image=make_image(), - bounding_box=make_bounding_box(format=BoundingBoxFormat.XYXY), + bounding_boxes=make_bounding_box(format=BoundingBoxFormat.XYXY), video=make_video(), str="str", int=0, @@ -434,7 +434,7 @@ class TestTransposeDimensions: def test_call(self, dims): sample = dict( image=make_image(), - bounding_box=make_bounding_box(format=BoundingBoxFormat.XYXY), + bounding_boxes=make_bounding_box(format=BoundingBoxFormat.XYXY), video=make_video(), str="str", int=0, @@ -496,7 +496,7 @@ def make_datapoints(): pil_image = 
to_image_pil(make_image(size=size, color_space="RGB")) target = { - "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float), + "boxes": make_bounding_box(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float), "labels": make_label(extra_dims=(num_objects,), categories=80), "masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long), } @@ -505,7 +505,7 @@ def make_datapoints(): tensor_image = torch.Tensor(make_image(size=size, color_space="RGB")) target = { - "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float), + "boxes": make_bounding_box(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float), "labels": make_label(extra_dims=(num_objects,), categories=80), "masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long), } @@ -514,7 +514,7 @@ def make_datapoints(): datapoint_image = make_image(size=size, color_space="RGB") target = { - "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float), + "boxes": make_bounding_box(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float), "labels": make_label(extra_dims=(num_objects,), categories=80), "masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long), } diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 093c378aa72..d5f448b09aa 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -1,10 +1,8 @@ import itertools import pathlib import random -import re import textwrap import warnings -from collections import defaultdict import numpy as np @@ -47,8 +45,8 @@ def make_pil_images(*args, **kwargs): def make_vanilla_tensor_bounding_boxes(*args, **kwargs): - for bounding_box in make_bounding_boxes(*args, **kwargs): - yield bounding_box.data + for bounding_boxes in make_bounding_boxes(*args, **kwargs): + yield bounding_boxes.data def parametrize(transforms_with_inputs): @@ -70,7 +68,7 @@ def auto_augment_adapter(transform, input, device): adapted_input = {} image_or_video_found = False for key, value in input.items(): - if isinstance(value, (datapoints.BoundingBox, datapoints.Mask)): + if isinstance(value, (datapoints.BoundingBoxes, datapoints.Mask)): # AA transforms don't support bounding boxes or masks continue elif check_type(value, (datapoints.Image, datapoints.Video, is_simple_tensor, PIL.Image.Image)): @@ -105,7 +103,7 @@ def normalize_adapter(transform, input, device): continue elif check_type(value, (datapoints.Image, datapoints.Video, is_simple_tensor)): # normalize doesn't support integer images - value = F.convert_dtype(value, torch.float32) + value = F.to_dtype(value, torch.float32, scale=True) adapted_input[key] = value return adapted_input @@ -144,9 +142,9 @@ class TestSmoke: (transforms.RandomZoomOut(p=1.0), None), (transforms.Resize([16, 16], antialias=True), None), (transforms.ScaleJitter((16, 16), scale_range=(0.8, 1.2), antialias=True), None), - (transforms.ClampBoundingBox(), None), + (transforms.ClampBoundingBoxes(), None), (transforms.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.CXCYWH), None), - (transforms.ConvertDtype(), None), + (transforms.ConvertImageDtype(), None), (transforms.GaussianBlur(kernel_size=3), None), ( transforms.LinearTransformation( @@ -175,22 +173,22 @@ class TestSmoke: ) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_common(self, transform, adapter, 
container_type, image_or_video, device): - spatial_size = F.get_spatial_size(image_or_video) + canvas_size = F.get_size(image_or_video) input = dict( image_or_video=image_or_video, - image_datapoint=make_image(size=spatial_size), - video_datapoint=make_video(size=spatial_size), - image_pil=next(make_pil_images(sizes=[spatial_size], color_spaces=["RGB"])), - bounding_box_xyxy=make_bounding_box( - format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=(3,) + image_datapoint=make_image(size=canvas_size), + video_datapoint=make_video(size=canvas_size), + image_pil=next(make_pil_images(sizes=[canvas_size], color_spaces=["RGB"])), + bounding_boxes_xyxy=make_bounding_box( + format=datapoints.BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(3,) ), - bounding_box_xywh=make_bounding_box( - format=datapoints.BoundingBoxFormat.XYWH, spatial_size=spatial_size, extra_dims=(4,) + bounding_boxes_xywh=make_bounding_box( + format=datapoints.BoundingBoxFormat.XYWH, canvas_size=canvas_size, batch_dims=(4,) ), - bounding_box_cxcywh=make_bounding_box( - format=datapoints.BoundingBoxFormat.CXCYWH, spatial_size=spatial_size, extra_dims=(5,) + bounding_boxes_cxcywh=make_bounding_box( + format=datapoints.BoundingBoxFormat.CXCYWH, canvas_size=canvas_size, batch_dims=(5,) ), - bounding_box_degenerate_xyxy=datapoints.BoundingBox( + bounding_boxes_degenerate_xyxy=datapoints.BoundingBoxes( [ [0, 0, 0, 0], # no height or width [0, 0, 0, 1], # no height @@ -200,9 +198,9 @@ def test_common(self, transform, adapter, container_type, image_or_video, device [2, 2, 1, 1], # x1 > x2, y1 > y2 ], format=datapoints.BoundingBoxFormat.XYXY, - spatial_size=spatial_size, + canvas_size=canvas_size, ), - bounding_box_degenerate_xywh=datapoints.BoundingBox( + bounding_boxes_degenerate_xywh=datapoints.BoundingBoxes( [ [0, 0, 0, 0], # no height or width [0, 0, 0, 1], # no height @@ -212,9 +210,9 @@ def test_common(self, transform, adapter, container_type, image_or_video, device [0, 0, -1, -1], # negative height and width ], format=datapoints.BoundingBoxFormat.XYWH, - spatial_size=spatial_size, + canvas_size=canvas_size, ), - bounding_box_degenerate_cxcywh=datapoints.BoundingBox( + bounding_boxes_degenerate_cxcywh=datapoints.BoundingBoxes( [ [0, 0, 0, 0], # no height or width [0, 0, 0, 1], # no height @@ -224,10 +222,10 @@ def test_common(self, transform, adapter, container_type, image_or_video, device [0, 0, -1, -1], # negative height and width ], format=datapoints.BoundingBoxFormat.CXCYWH, - spatial_size=spatial_size, + canvas_size=canvas_size, ), - detection_mask=make_detection_mask(size=spatial_size), - segmentation_mask=make_segmentation_mask(size=spatial_size), + detection_mask=make_detection_mask(size=canvas_size), + segmentation_mask=make_segmentation_mask(size=canvas_size), int=0, float=0.0, bool=True, @@ -262,7 +260,7 @@ def test_common(self, transform, adapter, container_type, image_or_video, device else: assert output_item is input_item - if isinstance(input_item, datapoints.BoundingBox) and not isinstance( + if isinstance(input_item, datapoints.BoundingBoxes) and not isinstance( transform, transforms.ConvertBoundingBoxFormat ): assert output_item.format == input_item.format @@ -272,10 +270,10 @@ def test_common(self, transform, adapter, container_type, image_or_video, device # TODO: we should test that against all degenerate boxes above for format in list(datapoints.BoundingBoxFormat): sample = dict( - boxes=datapoints.BoundingBox([[0, 0, 0, 0]], format=format, spatial_size=(224, 244)), + 
boxes=datapoints.BoundingBoxes([[0, 0, 0, 0]], format=format, canvas_size=(224, 244)), labels=torch.tensor([3]), ) - assert transforms.SanitizeBoundingBox()(sample)["boxes"].shape == (0, 4) + assert transforms.SanitizeBoundingBoxes()(sample)["boxes"].shape == (0, 4) @parametrize( [ @@ -289,7 +287,7 @@ def test_common(self, transform, adapter, container_type, image_or_video, device ], dtypes=[torch.uint8], extra_dims=[(), (4,)], - **(dict(num_frames=["random"]) if fn is make_videos else dict()), + **(dict(num_frames=[3]) if fn is make_videos else dict()), ) for fn in [ make_images, @@ -474,11 +472,11 @@ def test_assertions(self): @pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)]) @pytest.mark.parametrize("side_range", [(1.0, 4.0), [2.0, 5.0]]) - def test__get_params(self, fill, side_range, mocker): + def test__get_params(self, fill, side_range): transform = transforms.RandomZoomOut(fill=fill, side_range=side_range) - image = mocker.MagicMock(spec=datapoints.Image) - h, w = image.spatial_size = (24, 32) + h, w = size = (24, 32) + image = make_image(size) params = transform._get_params([image]) @@ -491,9 +489,7 @@ def test__get_params(self, fill, side_range, mocker): @pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)]) @pytest.mark.parametrize("side_range", [(1.0, 4.0), [2.0, 5.0]]) def test__transform(self, fill, side_range, mocker): - inpt = mocker.MagicMock(spec=datapoints.Image) - inpt.num_channels = 3 - inpt.spatial_size = (24, 32) + inpt = make_image((24, 32)) transform = transforms.RandomZoomOut(fill=fill, side_range=side_range, p=1) @@ -560,11 +556,9 @@ def test_assertions(self): @pytest.mark.parametrize("padding", [None, 1, [2, 3], [1, 2, 3, 4]]) @pytest.mark.parametrize("size, pad_if_needed", [((10, 10), False), ((50, 25), True)]) - def test__get_params(self, padding, pad_if_needed, size, mocker): - image = mocker.MagicMock(spec=datapoints.Image) - image.num_channels = 3 - image.spatial_size = (24, 32) - h, w = image.spatial_size + def test__get_params(self, padding, pad_if_needed, size): + h, w = size = (24, 32) + image = make_image(size) transform = transforms.RandomCrop(size, padding=padding, pad_if_needed=pad_if_needed) params = transform._get_params([image]) @@ -614,21 +608,16 @@ def test__transform(self, padding, pad_if_needed, fill, padding_mode, mocker): output_size, padding=padding, pad_if_needed=pad_if_needed, fill=fill, padding_mode=padding_mode ) - inpt = mocker.MagicMock(spec=datapoints.Image) - inpt.num_channels = 3 - inpt.spatial_size = (32, 32) + h, w = size = (32, 32) + inpt = make_image(size) - expected = mocker.MagicMock(spec=datapoints.Image) - expected.num_channels = 3 if isinstance(padding, int): - expected.spatial_size = (inpt.spatial_size[0] + padding, inpt.spatial_size[1] + padding) + new_size = (h + padding, w + padding) elif isinstance(padding, list): - expected.spatial_size = ( - inpt.spatial_size[0] + sum(padding[0::2]), - inpt.spatial_size[1] + sum(padding[1::2]), - ) + new_size = (h + sum(padding[0::2]), w + sum(padding[1::2])) else: - expected.spatial_size = inpt.spatial_size + new_size = size + expected = make_image(new_size) _ = mocker.patch("torchvision.transforms.v2.functional.pad", return_value=expected) fn_crop = mocker.patch("torchvision.transforms.v2.functional.crop") @@ -704,7 +693,7 @@ def test__transform(self, kernel_size, sigma, mocker): fn = mocker.patch("torchvision.transforms.v2.functional.gaussian_blur") inpt = mocker.MagicMock(spec=datapoints.Image) inpt.num_channels = 3 - inpt.spatial_size = (24, 32) + 
inpt.canvas_size = (24, 32) # vfdev-5, Feature Request: let's store params as Transform attribute # This could be also helpful for users @@ -750,16 +739,14 @@ def test_assertions(self): with pytest.raises(TypeError, match="Got inappropriate fill arg"): transforms.RandomPerspective(0.5, fill="abc") - def test__get_params(self, mocker): + def test__get_params(self): dscale = 0.5 transform = transforms.RandomPerspective(dscale) - image = mocker.MagicMock(spec=datapoints.Image) - image.num_channels = 3 - image.spatial_size = (24, 32) + + image = make_image((24, 32)) params = transform._get_params([image]) - h, w = image.spatial_size assert "coefficients" in params assert len(params["coefficients"]) == 8 @@ -770,9 +757,9 @@ def test__transform(self, distortion_scale, mocker): transform = transforms.RandomPerspective(distortion_scale, fill=fill, interpolation=interpolation) fn = mocker.patch("torchvision.transforms.v2.functional.perspective") - inpt = mocker.MagicMock(spec=datapoints.Image) - inpt.num_channels = 3 - inpt.spatial_size = (24, 32) + + inpt = make_image((24, 32)) + # vfdev-5, Feature Request: let's store params as Transform attribute # This could be also helpful for users # Otherwise, we can mock transform._get_params @@ -810,17 +797,16 @@ def test_assertions(self): with pytest.raises(TypeError, match="Got inappropriate fill arg"): transforms.ElasticTransform(1.0, 2.0, fill="abc") - def test__get_params(self, mocker): + def test__get_params(self): alpha = 2.0 sigma = 3.0 transform = transforms.ElasticTransform(alpha, sigma) - image = mocker.MagicMock(spec=datapoints.Image) - image.num_channels = 3 - image.spatial_size = (24, 32) + + h, w = size = (24, 32) + image = make_image(size) params = transform._get_params([image]) - h, w = image.spatial_size displacement = params["displacement"] assert displacement.shape == (1, h, w, 2) assert (-alpha / w <= displacement[0, ..., 0]).all() and (displacement[0, ..., 0] <= alpha / w).all() @@ -846,7 +832,7 @@ def test__transform(self, alpha, sigma, mocker): fn = mocker.patch("torchvision.transforms.v2.functional.elastic") inpt = mocker.MagicMock(spec=datapoints.Image) inpt.num_channels = 3 - inpt.spatial_size = (24, 32) + inpt.canvas_size = (24, 32) # Let's mock transform._get_params to control the output: transform._get_params = mocker.MagicMock() @@ -857,7 +843,7 @@ def test__transform(self, alpha, sigma, mocker): class TestRandomErasing: - def test_assertions(self, mocker): + def test_assertions(self): with pytest.raises(TypeError, match="Argument value should be either a number or str or a sequence"): transforms.RandomErasing(value={}) @@ -873,9 +859,7 @@ def test_assertions(self, mocker): with pytest.raises(ValueError, match="Scale should be between 0 and 1"): transforms.RandomErasing(scale=[-1, 2]) - image = mocker.MagicMock(spec=datapoints.Image) - image.num_channels = 3 - image.spatial_size = (24, 32) + image = make_image((24, 32)) transform = transforms.RandomErasing(value=[1, 2, 3, 4]) @@ -883,10 +867,9 @@ def test_assertions(self, mocker): transform._get_params([image]) @pytest.mark.parametrize("value", [5.0, [1, 2, 3], "random"]) - def test__get_params(self, value, mocker): - image = mocker.MagicMock(spec=datapoints.Image) - image.num_channels = 3 - image.spatial_size = (24, 32) + def test__get_params(self, value): + image = make_image((24, 32)) + num_channels, height, width = F.get_dimensions(image) transform = transforms.RandomErasing(value=value) params = transform._get_params([image]) @@ -896,14 +879,14 @@ def 
test__get_params(self, value, mocker): i, j = params["i"], params["j"] assert isinstance(v, torch.Tensor) if value == "random": - assert v.shape == (image.num_channels, h, w) + assert v.shape == (num_channels, h, w) elif isinstance(value, (int, float)): assert v.shape == (1, 1, 1) elif isinstance(value, (list, tuple)): - assert v.shape == (image.num_channels, 1, 1) + assert v.shape == (num_channels, 1, 1) - assert 0 <= i <= image.spatial_size[0] - h - assert 0 <= j <= image.spatial_size[1] - w + assert 0 <= i <= height - h + assert 0 <= j <= width - w @pytest.mark.parametrize("p", [0, 1]) def test__transform(self, mocker, p): @@ -943,7 +926,7 @@ def test__transform(self, mocker, p): class TestTransform: @pytest.mark.parametrize( "inpt_type", - [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBox, str, int], + [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int], ) def test_check_transformed_types(self, inpt_type, mocker): # This test ensures that we correctly handle which types to transform and which to bypass @@ -961,7 +944,7 @@ def test_check_transformed_types(self, inpt_type, mocker): class TestToImageTensor: @pytest.mark.parametrize( "inpt_type", - [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBox, str, int], + [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int], ) def test__transform(self, inpt_type, mocker): fn = mocker.patch( @@ -972,7 +955,7 @@ def test__transform(self, inpt_type, mocker): inpt = mocker.MagicMock(spec=inpt_type) transform = transforms.ToImageTensor() transform(inpt) - if inpt_type in (datapoints.BoundingBox, datapoints.Image, str, int): + if inpt_type in (datapoints.BoundingBoxes, datapoints.Image, str, int): assert fn.call_count == 0 else: fn.assert_called_once_with(inpt) @@ -981,7 +964,7 @@ def test__transform(self, inpt_type, mocker): class TestToImagePIL: @pytest.mark.parametrize( "inpt_type", - [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBox, str, int], + [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int], ) def test__transform(self, inpt_type, mocker): fn = mocker.patch("torchvision.transforms.v2.functional.to_image_pil") @@ -989,7 +972,7 @@ def test__transform(self, inpt_type, mocker): inpt = mocker.MagicMock(spec=inpt_type) transform = transforms.ToImagePIL() transform(inpt) - if inpt_type in (datapoints.BoundingBox, PIL.Image.Image, str, int): + if inpt_type in (datapoints.BoundingBoxes, PIL.Image.Image, str, int): assert fn.call_count == 0 else: fn.assert_called_once_with(inpt, mode=transform.mode) @@ -998,7 +981,7 @@ def test__transform(self, inpt_type, mocker): class TestToPILImage: @pytest.mark.parametrize( "inpt_type", - [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBox, str, int], + [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int], ) def test__transform(self, inpt_type, mocker): fn = mocker.patch("torchvision.transforms.v2.functional.to_image_pil") @@ -1006,7 +989,7 @@ def test__transform(self, inpt_type, mocker): inpt = mocker.MagicMock(spec=inpt_type) transform = transforms.ToPILImage() transform(inpt) - if inpt_type in (PIL.Image.Image, datapoints.BoundingBox, str, int): + if inpt_type in (PIL.Image.Image, datapoints.BoundingBoxes, str, int): assert fn.call_count == 0 else: fn.assert_called_once_with(inpt, mode=transform.mode) 
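Side note on the test pattern in the hunks above: these changes stop mocking datapoints.Image (a MagicMock with a fake .spatial_size attribute) and instead construct real datapoints through the make_image helper, reading sizes back via the renamed query F.get_size. The sketch below illustrates that pattern only; the make_image defined here is a simplified, hypothetical stand-in for the test helper of the same name, and it assumes the beta datapoints / transforms.v2 API that this patch targets.

import torch
import torchvision

# Beta API, as in the gallery example; acknowledges that it may still change.
torchvision.disable_beta_transforms_warning()

from torchvision import datapoints
from torchvision.transforms.v2 import functional as F


def make_image(size, *, color_space="RGB", dtype=torch.uint8):
    # Simplified stand-in for the test helper: random pixel data wrapped in a
    # datapoints.Image, so transforms can dispatch on the real type.
    num_channels = {"GRAY": 1, "RGB": 3}[color_space]
    data = torch.randint(0, 256, (num_channels, *size), dtype=dtype)
    return datapoints.Image(data)


image = make_image((24, 32))
height, width = F.get_size(image)  # formerly F.get_spatial_size
assert (height, width) == (24, 32)

Using a real Image instead of a mock means the tests exercise the actual metadata queries (size, number of channels) rather than attributes that no longer exist after the rename.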
@@ -1015,7 +998,7 @@ def test__transform(self, inpt_type, mocker): class TestToTensor: @pytest.mark.parametrize( "inpt_type", - [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBox, str, int], + [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int], ) def test__transform(self, inpt_type, mocker): fn = mocker.patch("torchvision.transforms.functional.to_tensor") @@ -1024,7 +1007,7 @@ def test__transform(self, inpt_type, mocker): with pytest.warns(UserWarning, match="deprecated and will be removed"): transform = transforms.ToTensor() transform(inpt) - if inpt_type in (datapoints.Image, torch.Tensor, datapoints.BoundingBox, str, int): + if inpt_type in (datapoints.Image, torch.Tensor, datapoints.BoundingBoxes, str, int): assert fn.call_count == 0 else: fn.assert_called_once_with(inpt) @@ -1062,14 +1045,13 @@ def test_assertions(self): class TestRandomIoUCrop: @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("options", [[0.5, 0.9], [2.0]]) - def test__get_params(self, device, options, mocker): - image = mocker.MagicMock(spec=datapoints.Image) - image.num_channels = 3 - image.spatial_size = (24, 32) - bboxes = datapoints.BoundingBox( + def test__get_params(self, device, options): + orig_h, orig_w = size = (24, 32) + image = make_image(size) + bboxes = datapoints.BoundingBoxes( torch.tensor([[1, 1, 10, 10], [20, 20, 23, 23], [1, 20, 10, 23], [20, 1, 23, 10]]), format="XYXY", - spatial_size=image.spatial_size, + canvas_size=size, device=device, ) sample = [image, bboxes] @@ -1088,8 +1070,6 @@ def test__get_params(self, device, options, mocker): assert len(params["is_within_crop_area"]) > 0 assert params["is_within_crop_area"].dtype == torch.bool - orig_h = image.spatial_size[0] - orig_w = image.spatial_size[1] assert int(transform.min_scale * orig_h) <= params["height"] <= int(transform.max_scale * orig_h) assert int(transform.min_scale * orig_w) <= params["width"] <= int(transform.max_scale * orig_w) @@ -1104,7 +1084,7 @@ def test__get_params(self, device, options, mocker): def test__transform_empty_params(self, mocker): transform = transforms.RandomIoUCrop(sampler_options=[2.0]) image = datapoints.Image(torch.rand(1, 3, 4, 4)) - bboxes = datapoints.BoundingBox(torch.tensor([[1, 1, 2, 2]]), format="XYXY", spatial_size=(4, 4)) + bboxes = datapoints.BoundingBoxes(torch.tensor([[1, 1, 2, 2]]), format="XYXY", canvas_size=(4, 4)) label = torch.tensor([1]) sample = [image, bboxes, label] # Let's mock transform._get_params to control the output: @@ -1123,9 +1103,10 @@ def test_forward_assertion(self): def test__transform(self, mocker): transform = transforms.RandomIoUCrop() - image = datapoints.Image(torch.rand(3, 32, 24)) - bboxes = make_bounding_box(format="XYXY", spatial_size=(32, 24), extra_dims=(6,)) - masks = make_detection_mask((32, 24), num_objects=6) + size = (32, 24) + image = make_image(size) + bboxes = make_bounding_box(format="XYXY", canvas_size=size, batch_dims=(6,)) + masks = make_detection_mask(size, num_objects=6) sample = [image, bboxes, masks] @@ -1148,7 +1129,7 @@ def test__transform(self, mocker): # check number of bboxes vs number of labels: output_bboxes = output[1] - assert isinstance(output_bboxes, datapoints.BoundingBox) + assert isinstance(output_bboxes, datapoints.BoundingBoxes) assert (output_bboxes[~is_within_crop_area] == 0).all() output_masks = output[2] @@ -1156,13 +1137,14 @@ def test__transform(self, mocker): class TestScaleJitter: - def test__get_params(self, 
mocker): - spatial_size = (24, 32) + def test__get_params(self): + canvas_size = (24, 32) target_size = (16, 12) scale_range = (0.5, 1.5) transform = transforms.ScaleJitter(target_size=target_size, scale_range=scale_range) - sample = mocker.MagicMock(spec=datapoints.Image, num_channels=3, spatial_size=spatial_size) + + sample = make_image(canvas_size) n_samples = 5 for _ in range(n_samples): @@ -1175,11 +1157,11 @@ def test__get_params(self, mocker): assert isinstance(size, tuple) and len(size) == 2 height, width = size - r_min = min(target_size[1] / spatial_size[0], target_size[0] / spatial_size[1]) * scale_range[0] - r_max = min(target_size[1] / spatial_size[0], target_size[0] / spatial_size[1]) * scale_range[1] + r_min = min(target_size[1] / canvas_size[0], target_size[0] / canvas_size[1]) * scale_range[0] + r_max = min(target_size[1] / canvas_size[0], target_size[0] / canvas_size[1]) * scale_range[1] - assert int(spatial_size[0] * r_min) <= height <= int(spatial_size[0] * r_max) - assert int(spatial_size[1] * r_min) <= width <= int(spatial_size[1] * r_max) + assert int(canvas_size[0] * r_min) <= height <= int(canvas_size[0] * r_max) + assert int(canvas_size[1] * r_min) <= width <= int(canvas_size[1] * r_max) def test__transform(self, mocker): interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode) @@ -1207,12 +1189,12 @@ def test__transform(self, mocker): class TestRandomShortestSize: @pytest.mark.parametrize("min_size,max_size", [([5, 9], 20), ([5, 9], None)]) - def test__get_params(self, min_size, max_size, mocker): - spatial_size = (3, 10) + def test__get_params(self, min_size, max_size): + canvas_size = (3, 10) transform = transforms.RandomShortestSize(min_size=min_size, max_size=max_size, antialias=True) - sample = mocker.MagicMock(spec=datapoints.Image, num_channels=3, spatial_size=spatial_size) + sample = make_image(canvas_size) params = transform._get_params([sample]) assert "size" in params @@ -1326,61 +1308,6 @@ def test__transform(self, mocker): ) -class TestToDtype: - @pytest.mark.parametrize( - ("dtype", "expected_dtypes"), - [ - ( - torch.float64, - { - datapoints.Video: torch.float64, - datapoints.Image: torch.float64, - datapoints.BoundingBox: torch.float64, - }, - ), - ( - {datapoints.Video: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64}, - {datapoints.Video: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64}, - ), - ], - ) - def test_call(self, dtype, expected_dtypes): - sample = dict( - video=make_video(dtype=torch.int64), - image=make_image(dtype=torch.uint8), - bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, dtype=torch.float32), - str="str", - int=0, - ) - - transform = transforms.ToDtype(dtype) - transformed_sample = transform(sample) - - for key, value in sample.items(): - value_type = type(value) - transformed_value = transformed_sample[key] - - # make sure the transformation retains the type - assert isinstance(transformed_value, value_type) - - if isinstance(value, torch.Tensor): - assert transformed_value.dtype is expected_dtypes[value_type] - else: - assert transformed_value is value - - @pytest.mark.filterwarnings("error") - def test_plain_tensor_call(self): - tensor = torch.empty((), dtype=torch.float32) - transform = transforms.ToDtype({torch.Tensor: torch.float64}) - - assert transform(tensor).dtype is torch.float64 - - @pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video]) - def test_plain_tensor_warning(self, other_type): - 
with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")): - transforms.ToDtype(dtype={torch.Tensor: torch.float32, other_type: torch.float64}) - - class TestUniformTemporalSubsample: @pytest.mark.parametrize( "inpt", @@ -1547,7 +1474,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize): elif data_augmentation == "ssd": t = [ transforms.RandomPhotometricDistort(p=1), - transforms.RandomZoomOut(fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0}), p=1), + transforms.RandomZoomOut(fill={"others": (123.0, 117.0, 104.0), datapoints.Mask: 0}, p=1), transforms.RandomIoUCrop(), transforms.RandomHorizontalFlip(p=1), to_tensor, @@ -1561,7 +1488,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize): transforms.ConvertImageDtype(torch.float), ] if sanitize: - t += [transforms.SanitizeBoundingBox()] + t += [transforms.SanitizeBoundingBoxes()] t = transforms.Compose(t) num_boxes = 5 @@ -1579,7 +1506,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize): boxes = torch.randint(0, min(H, W) // 2, size=(num_boxes, 4)) boxes[:, 2:] += boxes[:, :2] boxes = boxes.clamp(min=0, max=min(H, W)) - boxes = datapoints.BoundingBox(boxes, format="XYXY", spatial_size=(H, W)) + boxes = datapoints.BoundingBoxes(boxes, format="XYXY", canvas_size=(H, W)) masks = datapoints.Mask(torch.randint(0, 2, size=(num_boxes, H, W), dtype=torch.uint8)) @@ -1602,7 +1529,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize): # ssd and ssdlite contain RandomIoUCrop which may "remove" some bbox. It # doesn't remove them strictly speaking, it just marks some boxes as # degenerate and those boxes will be later removed by - # SanitizeBoundingBox(), which we add to the pipelines if the sanitize + # SanitizeBoundingBoxes(), which we add to the pipelines if the sanitize # param is True. # Note that the values below are probably specific to the random seed # set above (which is fine). 
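For context on the SanitizeBoundingBoxes usage exercised in the detection presets above and in the tests that follow: it removes degenerate or too-small boxes and, via labels_getter, the matching label entries. The snippet below is a minimal illustration under the renamed API from this patch (BoundingBoxes, canvas_size, SanitizeBoundingBoxes); it is not code from the diff, and the sample values are made up.

import torch
import torchvision

torchvision.disable_beta_transforms_warning()

from torchvision import datapoints
from torchvision.transforms import v2

boxes = datapoints.BoundingBoxes(
    [
        [0, 0, 10, 10],  # valid box, kept
        [5, 5, 5, 5],    # zero width and height -> degenerate, removed
    ],
    format="XYXY",
    canvas_size=(32, 32),
)
sample = {"boxes": boxes, "labels": torch.tensor([1, 2])}

# labels_getter="default" looks for a "labels" key; a callable works as well.
out = v2.SanitizeBoundingBoxes(labels_getter="default")(sample)
print(out["boxes"].shape, out["labels"])  # expected: torch.Size([1, 4]) tensor([1])

This is why the presets only append SanitizeBoundingBoxes when sanitize=True: transforms such as RandomIoUCrop merely mark boxes as degenerate, and this transform is what actually filters them (and their labels) out.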
@@ -1614,9 +1541,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize): @pytest.mark.parametrize("min_size", (1, 10)) -@pytest.mark.parametrize( - "labels_getter", ("default", "labels", lambda inputs: inputs["labels"], None, lambda inputs: None) -) +@pytest.mark.parametrize("labels_getter", ("default", lambda inputs: inputs["labels"], None, lambda inputs: None)) @pytest.mark.parametrize("sample_type", (tuple, dict)) def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type): @@ -1652,10 +1577,10 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type): boxes = torch.tensor(boxes) labels = torch.arange(boxes.shape[0]) - boxes = datapoints.BoundingBox( + boxes = datapoints.BoundingBoxes( boxes, format=datapoints.BoundingBoxFormat.XYXY, - spatial_size=(H, W), + canvas_size=(H, W), ) masks = datapoints.Mask(torch.randint(0, 2, size=(boxes.shape[0], H, W))) @@ -1674,7 +1599,7 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type): img = sample.pop("image") sample = (img, sample) - out = transforms.SanitizeBoundingBox(min_size=min_size, labels_getter=labels_getter)(sample) + out = transforms.SanitizeBoundingBoxes(min_size=min_size, labels_getter=labels_getter)(sample) if sample_type is tuple: out_image = out[0] @@ -1692,7 +1617,7 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type): assert out_image is input_img assert out_whatever is whatever - assert isinstance(out_boxes, datapoints.BoundingBox) + assert isinstance(out_boxes, datapoints.BoundingBoxes) assert isinstance(out_masks, datapoints.Mask) if labels_getter is None or (callable(labels_getter) and labels_getter({"labels": "blah"}) is None): @@ -1704,62 +1629,42 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type): assert out_labels.tolist() == valid_indices -@pytest.mark.parametrize("key", ("labels", "LABELS", "LaBeL", "SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT")) -@pytest.mark.parametrize("sample_type", (tuple, dict)) -def test_sanitize_bounding_boxes_default_heuristic(key, sample_type): - labels = torch.arange(10) - sample = {key: labels, "another_key": "whatever"} - if sample_type is tuple: - sample = (None, sample, "whatever_again") - assert transforms.SanitizeBoundingBox._find_labels_default_heuristic(sample) is labels - - if key.lower() != "labels": - # If "labels" is in the dict (case-insensitive), - # it takes precedence over other keys which would otherwise be a match - d = {key: "something_else", "labels": labels} - assert transforms.SanitizeBoundingBox._find_labels_default_heuristic(d) is labels - - def test_sanitize_bounding_boxes_errors(): - good_bbox = datapoints.BoundingBox( + good_bbox = datapoints.BoundingBoxes( [[0, 0, 10, 10]], format=datapoints.BoundingBoxFormat.XYXY, - spatial_size=(20, 20), + canvas_size=(20, 20), ) with pytest.raises(ValueError, match="min_size must be >= 1"): - transforms.SanitizeBoundingBox(min_size=0) - with pytest.raises(ValueError, match="labels_getter should either be a str"): - transforms.SanitizeBoundingBox(labels_getter=12) + transforms.SanitizeBoundingBoxes(min_size=0) + with pytest.raises(ValueError, match="labels_getter should either be 'default'"): + transforms.SanitizeBoundingBoxes(labels_getter=12) with pytest.raises(ValueError, match="Could not infer where the labels are"): bad_labels_key = {"bbox": good_bbox, "BAD_KEY": torch.arange(good_bbox.shape[0])} - transforms.SanitizeBoundingBox()(bad_labels_key) - - with pytest.raises(ValueError, match="If labels_getter is a str 
or 'default'"): - not_a_dict = (good_bbox, torch.arange(good_bbox.shape[0])) - transforms.SanitizeBoundingBox()(not_a_dict) + transforms.SanitizeBoundingBoxes()(bad_labels_key) with pytest.raises(ValueError, match="must be a tensor"): not_a_tensor = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0]).tolist()} - transforms.SanitizeBoundingBox()(not_a_tensor) + transforms.SanitizeBoundingBoxes()(not_a_tensor) with pytest.raises(ValueError, match="Number of boxes"): different_sizes = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0] + 3)} - transforms.SanitizeBoundingBox()(different_sizes) + transforms.SanitizeBoundingBoxes()(different_sizes) with pytest.raises(ValueError, match="boxes must be of shape"): - bad_bbox = datapoints.BoundingBox( # batch with 2 elements + bad_bbox = datapoints.BoundingBoxes( # batch with 2 elements [ [[0, 0, 10, 10]], [[0, 0, 10, 10]], ], format=datapoints.BoundingBoxFormat.XYXY, - spatial_size=(20, 20), + canvas_size=(20, 20), ) different_sizes = {"bbox": bad_bbox, "labels": torch.arange(bad_bbox.shape[0])} - transforms.SanitizeBoundingBox()(different_sizes) + transforms.SanitizeBoundingBoxes()(different_sizes) @pytest.mark.parametrize( diff --git a/test/test_transforms_v2_consistency.py b/test/test_transforms_v2_consistency.py index f035dde45ed..f5ea69279a1 100644 --- a/test/test_transforms_v2_consistency.py +++ b/test/test_transforms_v2_consistency.py @@ -4,7 +4,6 @@ import inspect import random import re -from collections import defaultdict from pathlib import Path import numpy as np @@ -30,8 +29,9 @@ from torchvision.transforms import functional as legacy_F from torchvision.transforms.v2 import functional as prototype_F +from torchvision.transforms.v2._utils import _get_fill from torchvision.transforms.v2.functional import to_image_pil -from torchvision.transforms.v2.utils import query_spatial_size +from torchvision.transforms.v2.utils import query_size DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=["RGB"], extra_dims=[(4,)]) @@ -191,7 +191,7 @@ def __init__( closeness_kwargs=dict(rtol=None, atol=None), ), ConsistencyConfig( - v2_transforms.ConvertDtype, + v2_transforms.ConvertImageDtype, legacy_transforms.ConvertImageDtype, [ ArgsKwargs(torch.float16), @@ -1090,7 +1090,7 @@ def make_label(extra_dims, categories): pil_image = to_image_pil(make_image(size=size, color_space="RGB")) target = { - "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float), + "boxes": make_bounding_box(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float), "labels": make_label(extra_dims=(num_objects,), categories=80), } if with_mask: @@ -1098,9 +1098,9 @@ def make_label(extra_dims, categories): yield (pil_image, target) - tensor_image = torch.Tensor(make_image(size=size, color_space="RGB")) + tensor_image = torch.Tensor(make_image(size=size, color_space="RGB", dtype=torch.float32)) target = { - "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float), + "boxes": make_bounding_box(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float), "labels": make_label(extra_dims=(num_objects,), categories=80), } if with_mask: @@ -1108,9 +1108,9 @@ def make_label(extra_dims, categories): yield (tensor_image, target) - datapoint_image = make_image(size=size, color_space="RGB") + datapoint_image = make_image(size=size, color_space="RGB", dtype=torch.float32) target = { - "boxes": make_bounding_box(spatial_size=size, 
format="XYXY", extra_dims=(num_objects,), dtype=torch.float), + "boxes": make_bounding_box(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float), "labels": make_label(extra_dims=(num_objects,), categories=80), } if with_mask: @@ -1127,13 +1127,13 @@ def make_label(extra_dims, categories): v2_transforms.Compose( [ v2_transforms.RandomIoUCrop(), - v2_transforms.SanitizeBoundingBox(labels_getter=lambda sample: sample[1]["labels"]), + v2_transforms.SanitizeBoundingBoxes(labels_getter=lambda sample: sample[1]["labels"]), ] ), {"with_mask": False}, ), (det_transforms.RandomZoomOut(), v2_transforms.RandomZoomOut(), {"with_mask": False}), - (det_transforms.ScaleJitter((1024, 1024)), v2_transforms.ScaleJitter((1024, 1024)), {}), + (det_transforms.ScaleJitter((1024, 1024)), v2_transforms.ScaleJitter((1024, 1024), antialias=True), {}), ( det_transforms.RandomShortestSize( min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333 @@ -1172,7 +1172,7 @@ def __init__(self, size, fill=0): self.fill = v2_transforms._geometry._setup_fill_arg(fill) def _get_params(self, sample): - height, width = query_spatial_size(sample) + height, width = query_size(sample) padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)] needs_padding = any(padding) return dict(padding=padding, needs_padding=needs_padding) @@ -1181,7 +1181,7 @@ def _transform(self, inpt, params): if not params["needs_padding"]: return inpt - fill = self.fill[type(inpt)] + fill = _get_fill(self.fill, type(inpt)) return prototype_F.pad(inpt, padding=params["padding"], fill=fill) @@ -1243,7 +1243,7 @@ def check(self, t, t_ref, data_kwargs=None): seg_transforms.RandomCrop(size=480), v2_transforms.Compose( [ - PadIfSmaller(size=480, fill=defaultdict(lambda: 0, {datapoints.Mask: 255})), + PadIfSmaller(size=480, fill={datapoints.Mask: 255, "others": 0}), v2_transforms.RandomCrop(size=480), ] ), diff --git a/test/test_transforms_v2_functional.py b/test/test_transforms_v2_functional.py index 79ea20d854e..230695ff93e 100644 --- a/test/test_transforms_v2_functional.py +++ b/test/test_transforms_v2_functional.py @@ -26,7 +26,7 @@ from torchvision.transforms.functional import _get_perspective_coeffs from torchvision.transforms.v2 import functional as F from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding -from torchvision.transforms.v2.functional._meta import clamp_bounding_box, convert_format_bounding_box +from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_format_bounding_boxes from torchvision.transforms.v2.utils import is_simple_tensor from transforms_v2_dispatcher_infos import DISPATCHER_INFOS from transforms_v2_kernel_infos import KERNEL_INFOS @@ -176,7 +176,7 @@ def test_batched_vs_single(self, test_id, info, args_kwargs, device): # Everything to the left is considered a batch dimension. data_dims = { datapoints.Image: 3, - datapoints.BoundingBox: 1, + datapoints.BoundingBoxes: 1, # `Mask`'s are special in the sense that the data dimensions depend on the type of mask. For detection masks # it is 3 `(*, N, H, W)`, but for segmentation masks it is 2 `(*, H, W)`. Since both a grouped under one # type all kernels should also work without differentiating between the two. 
Thus, we go with 2 here as @@ -283,12 +283,12 @@ def test_float32_vs_uint8(self, test_id, info, args_kwargs): adapted_other_args, adapted_kwargs = info.float32_vs_uint8(other_args, kwargs) actual = info.kernel( - F.convert_dtype_image_tensor(input, dtype=torch.float32), + F.to_dtype_image_tensor(input, dtype=torch.float32, scale=True), *adapted_other_args, **adapted_kwargs, ) - expected = F.convert_dtype_image_tensor(info.kernel(input, *other_args, **kwargs), dtype=torch.float32) + expected = F.to_dtype_image_tensor(info.kernel(input, *other_args, **kwargs), dtype=torch.float32, scale=True) assert_close( actual, @@ -351,7 +351,7 @@ def test_scripted_smoke(self, info, args_kwargs, device): F.get_image_size, F.get_num_channels, F.get_num_frames, - F.get_spatial_size, + F.get_size, F.rgb_to_grayscale, F.uniform_temporal_subsample, ], @@ -515,15 +515,15 @@ def test_unkown_type(self, info): [ info for info in DISPATCHER_INFOS - if datapoints.BoundingBox in info.kernels and info.dispatcher is not F.convert_format_bounding_box + if datapoints.BoundingBoxes in info.kernels and info.dispatcher is not F.convert_format_bounding_boxes ], - args_kwargs_fn=lambda info: info.sample_inputs(datapoints.BoundingBox), + args_kwargs_fn=lambda info: info.sample_inputs(datapoints.BoundingBoxes), ) - def test_bounding_box_format_consistency(self, info, args_kwargs): - (bounding_box, *other_args), kwargs = args_kwargs.load() - format = bounding_box.format + def test_bounding_boxes_format_consistency(self, info, args_kwargs): + (bounding_boxes, *other_args), kwargs = args_kwargs.load() + format = bounding_boxes.format - output = info.dispatcher(bounding_box, *other_args, **kwargs) + output = info.dispatcher(bounding_boxes, *other_args, **kwargs) assert output.format == format @@ -538,7 +538,6 @@ def test_bounding_box_format_consistency(self, info, args_kwargs): (F.get_image_num_channels, F.get_num_channels), (F.to_pil_image, F.to_image_pil), (F.elastic_transform, F.elastic), - (F.convert_image_dtype, F.convert_dtype_image_tensor), (F.to_grayscale, F.rgb_to_grayscale), ] ], @@ -547,24 +546,6 @@ def test_alias(alias, target): assert alias is target -@pytest.mark.parametrize( - ("info", "args_kwargs"), - make_info_args_kwargs_params( - KERNEL_INFOS_MAP[F.convert_dtype_image_tensor], - args_kwargs_fn=lambda info: info.sample_inputs_fn(), - ), -) -@pytest.mark.parametrize("device", cpu_and_cuda()) -def test_convert_dtype_image_tensor_dtype_and_device(info, args_kwargs, device): - (input, *other_args), kwargs = args_kwargs.load(device) - dtype = other_args[0] if other_args else kwargs.get("dtype", torch.float32) - - output = info.kernel(input, dtype) - - assert output.dtype == dtype - assert output.device == input.device - - @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("num_channels", [1, 3]) def test_normalize_image_tensor_stats(device, num_channels): @@ -581,37 +562,37 @@ def assert_samples_from_standard_normal(t): assert_samples_from_standard_normal(F.normalize_image_tensor(image, mean, std)) -class TestClampBoundingBox: +class TestClampBoundingBoxes: @pytest.mark.parametrize( "metadata", [ dict(), dict(format=datapoints.BoundingBoxFormat.XYXY), - dict(spatial_size=(1, 1)), + dict(canvas_size=(1, 1)), ], ) def test_simple_tensor_insufficient_metadata(self, metadata): simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor) - with pytest.raises(ValueError, match=re.escape("`format` and `spatial_size` has to be passed")): - F.clamp_bounding_box(simple_tensor, **metadata) + 
with pytest.raises(ValueError, match=re.escape("`format` and `canvas_size` has to be passed")): + F.clamp_bounding_boxes(simple_tensor, **metadata) @pytest.mark.parametrize( "metadata", [ dict(format=datapoints.BoundingBoxFormat.XYXY), - dict(spatial_size=(1, 1)), - dict(format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(1, 1)), + dict(canvas_size=(1, 1)), + dict(format=datapoints.BoundingBoxFormat.XYXY, canvas_size=(1, 1)), ], ) def test_datapoint_explicit_metadata(self, metadata): datapoint = next(make_bounding_boxes()) - with pytest.raises(ValueError, match=re.escape("`format` and `spatial_size` must not be passed")): - F.clamp_bounding_box(datapoint, **metadata) + with pytest.raises(ValueError, match=re.escape("`format` and `canvas_size` must not be passed")): + F.clamp_bounding_boxes(datapoint, **metadata) -class TestConvertFormatBoundingBox: +class TestConvertFormatBoundingBoxes: @pytest.mark.parametrize( ("inpt", "old_format"), [ @@ -621,19 +602,19 @@ class TestConvertFormatBoundingBox: ) def test_missing_new_format(self, inpt, old_format): with pytest.raises(TypeError, match=re.escape("missing 1 required argument: 'new_format'")): - F.convert_format_bounding_box(inpt, old_format) + F.convert_format_bounding_boxes(inpt, old_format) def test_simple_tensor_insufficient_metadata(self): simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor) with pytest.raises(ValueError, match=re.escape("`old_format` has to be passed")): - F.convert_format_bounding_box(simple_tensor, new_format=datapoints.BoundingBoxFormat.CXCYWH) + F.convert_format_bounding_boxes(simple_tensor, new_format=datapoints.BoundingBoxFormat.CXCYWH) def test_datapoint_explicit_metadata(self): datapoint = next(make_bounding_boxes()) with pytest.raises(ValueError, match=re.escape("`old_format` must not be passed")): - F.convert_format_bounding_box( + F.convert_format_bounding_boxes( datapoint, old_format=datapoint.format, new_format=datapoints.BoundingBoxFormat.CXCYWH ) @@ -665,163 +646,6 @@ def _compute_affine_matrix(angle_, translate_, scale_, shear_, center_): return true_matrix -@pytest.mark.parametrize("angle", range(-90, 90, 56)) -@pytest.mark.parametrize("expand, center", [(True, None), (False, None), (False, (12, 14))]) -def test_correctness_rotate_bounding_box(angle, expand, center): - def _compute_expected_bbox(bbox, angle_, expand_, center_): - affine_matrix = _compute_affine_matrix(angle_, [0.0, 0.0], 1.0, [0.0, 0.0], center_) - affine_matrix = affine_matrix[:2, :] - - height, width = bbox.spatial_size - bbox_xyxy = convert_format_bounding_box(bbox, new_format=datapoints.BoundingBoxFormat.XYXY) - points = np.array( - [ - [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0], - [bbox_xyxy[2].item(), bbox_xyxy[1].item(), 1.0], - [bbox_xyxy[0].item(), bbox_xyxy[3].item(), 1.0], - [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0], - # image frame - [0.0, 0.0, 1.0], - [0.0, height, 1.0], - [width, height, 1.0], - [width, 0.0, 1.0], - ] - ) - transformed_points = np.matmul(points, affine_matrix.T) - out_bbox = [ - float(np.min(transformed_points[:4, 0])), - float(np.min(transformed_points[:4, 1])), - float(np.max(transformed_points[:4, 0])), - float(np.max(transformed_points[:4, 1])), - ] - if expand_: - tr_x = np.min(transformed_points[4:, 0]) - tr_y = np.min(transformed_points[4:, 1]) - out_bbox[0] -= tr_x - out_bbox[1] -= tr_y - out_bbox[2] -= tr_x - out_bbox[3] -= tr_y - - height = int(height - 2 * tr_y) - width = int(width - 2 * tr_x) - - out_bbox = datapoints.BoundingBox( - out_bbox, - 
format=datapoints.BoundingBoxFormat.XYXY, - spatial_size=(height, width), - dtype=bbox.dtype, - device=bbox.device, - ) - out_bbox = clamp_bounding_box(convert_format_bounding_box(out_bbox, new_format=bbox.format)) - return out_bbox, (height, width) - - spatial_size = (32, 38) - - for bboxes in make_bounding_boxes(spatial_size=spatial_size, extra_dims=((4,),)): - bboxes_format = bboxes.format - bboxes_spatial_size = bboxes.spatial_size - - output_bboxes, output_spatial_size = F.rotate_bounding_box( - bboxes.as_subclass(torch.Tensor), - format=bboxes_format, - spatial_size=bboxes_spatial_size, - angle=angle, - expand=expand, - center=center, - ) - - center_ = center - if center_ is None: - center_ = [s * 0.5 for s in bboxes_spatial_size[::-1]] - - if bboxes.ndim < 2: - bboxes = [bboxes] - - expected_bboxes = [] - for bbox in bboxes: - bbox = datapoints.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size) - expected_bbox, expected_spatial_size = _compute_expected_bbox(bbox, -angle, expand, center_) - expected_bboxes.append(expected_bbox) - if len(expected_bboxes) > 1: - expected_bboxes = torch.stack(expected_bboxes) - else: - expected_bboxes = expected_bboxes[0] - torch.testing.assert_close(output_bboxes, expected_bboxes, atol=1, rtol=0) - torch.testing.assert_close(output_spatial_size, expected_spatial_size, atol=1, rtol=0) - - -@pytest.mark.parametrize("device", cpu_and_cuda()) -@pytest.mark.parametrize("expand", [False]) # expand=True does not match D2 -def test_correctness_rotate_bounding_box_on_fixed_input(device, expand): - # Check transformation against known expected output - format = datapoints.BoundingBoxFormat.XYXY - spatial_size = (64, 64) - # xyxy format - in_boxes = [ - [1, 1, 5, 5], - [1, spatial_size[0] - 6, 5, spatial_size[0] - 2], - [spatial_size[1] - 6, spatial_size[0] - 6, spatial_size[1] - 2, spatial_size[0] - 2], - [spatial_size[1] // 2 - 10, spatial_size[0] // 2 - 10, spatial_size[1] // 2 + 10, spatial_size[0] // 2 + 10], - ] - in_boxes = torch.tensor(in_boxes, dtype=torch.float64, device=device) - # Tested parameters - angle = 45 - center = None if expand else [12, 23] - - # # Expected bboxes computed using Detectron2: - # from detectron2.data.transforms import RotationTransform, AugmentationList - # from detectron2.data.transforms import AugInput - # import cv2 - # inpt = AugInput(im1, boxes=np.array(in_boxes, dtype="float32")) - # augs = AugmentationList([RotationTransform(*size, angle, expand=expand, center=center, interp=cv2.INTER_NEAREST), ]) - # out = augs(inpt) - # print(inpt.boxes) - if expand: - expected_bboxes = [ - [1.65937957, 42.67157288, 7.31623382, 48.32842712], - [41.96446609, 82.9766594, 47.62132034, 88.63351365], - [82.26955262, 42.67157288, 87.92640687, 48.32842712], - [31.35786438, 31.35786438, 59.64213562, 59.64213562], - ] - else: - expected_bboxes = [ - [-11.33452378, 12.39339828, -5.67766953, 18.05025253], - [28.97056275, 52.69848481, 34.627417, 58.35533906], - [69.27564928, 12.39339828, 74.93250353, 18.05025253], - [18.36396103, 1.07968978, 46.64823228, 29.36396103], - ] - expected_bboxes = clamp_bounding_box( - datapoints.BoundingBox(expected_bboxes, format="XYXY", spatial_size=spatial_size) - ).tolist() - - output_boxes, _ = F.rotate_bounding_box( - in_boxes, - format=format, - spatial_size=spatial_size, - angle=angle, - expand=expand, - center=center, - ) - - torch.testing.assert_close(output_boxes.tolist(), expected_bboxes) - - -@pytest.mark.parametrize("device", cpu_and_cuda()) -def 
test_correctness_rotate_segmentation_mask_on_fixed_input(device): - # Check transformation against known expected output and CPU/CUDA devices - - # Create a fixed input segmentation mask with 2 square masks - # in top-left, bottom-left corners - mask = torch.zeros(1, 32, 32, dtype=torch.long, device=device) - mask[0, 2:10, 2:10] = 1 - mask[0, 32 - 9 : 32 - 3, 3:9] = 2 - - # Rotate 90 degrees - expected_mask = torch.rot90(mask, k=1, dims=(-2, -1)) - out_mask = F.rotate_mask(mask, 90, expand=False) - torch.testing.assert_close(out_mask, expected_mask) - - @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize( "format", @@ -834,7 +658,7 @@ def test_correctness_rotate_segmentation_mask_on_fixed_input(device): [-8, 12, 70, 40, [(-2.0, 23.0, 13.0, 43.0), (38.0, 13.0, 58.0, 30.0), (33.0, 54.0, 44.0, 70.0)]], ], ) -def test_correctness_crop_bounding_box(device, format, top, left, height, width, expected_bboxes): +def test_correctness_crop_bounding_boxes(device, format, top, left, height, width, expected_bboxes): # Expected bboxes computed using Albumentations: # import numpy as np @@ -849,7 +673,7 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width, # expected_bboxes.append(out_box) format = datapoints.BoundingBoxFormat.XYXY - spatial_size = (64, 76) + canvas_size = (64, 76) in_boxes = [ [10.0, 15.0, 25.0, 35.0], [50.0, 5.0, 70.0, 22.0], @@ -857,26 +681,26 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width, ] in_boxes = torch.tensor(in_boxes, device=device) if format != datapoints.BoundingBoxFormat.XYXY: - in_boxes = convert_format_bounding_box(in_boxes, datapoints.BoundingBoxFormat.XYXY, format) + in_boxes = convert_format_bounding_boxes(in_boxes, datapoints.BoundingBoxFormat.XYXY, format) - expected_bboxes = clamp_bounding_box( - datapoints.BoundingBox(expected_bboxes, format="XYXY", spatial_size=spatial_size) + expected_bboxes = clamp_bounding_boxes( + datapoints.BoundingBoxes(expected_bboxes, format="XYXY", canvas_size=canvas_size) ).tolist() - output_boxes, output_spatial_size = F.crop_bounding_box( + output_boxes, output_canvas_size = F.crop_bounding_boxes( in_boxes, format, top, left, - spatial_size[0], - spatial_size[1], + canvas_size[0], + canvas_size[1], ) if format != datapoints.BoundingBoxFormat.XYXY: - output_boxes = convert_format_bounding_box(output_boxes, format, datapoints.BoundingBoxFormat.XYXY) + output_boxes = convert_format_bounding_boxes(output_boxes, format, datapoints.BoundingBoxFormat.XYXY) torch.testing.assert_close(output_boxes.tolist(), expected_bboxes) - torch.testing.assert_close(output_spatial_size, spatial_size) + torch.testing.assert_close(output_canvas_size, canvas_size) @pytest.mark.parametrize("device", cpu_and_cuda()) @@ -903,7 +727,7 @@ def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device): [-5, 5, 35, 45, (32, 34)], ], ) -def test_correctness_resized_crop_bounding_box(device, format, top, left, height, width, size): +def test_correctness_resized_crop_bounding_boxes(device, format, top, left, height, width, size): def _compute_expected_bbox(bbox, top_, left_, height_, width_, size_): # bbox should be xyxy bbox[0] = (bbox[0] - left_) * size_[1] / width_ @@ -913,7 +737,7 @@ def _compute_expected_bbox(bbox, top_, left_, height_, width_, size_): return bbox format = datapoints.BoundingBoxFormat.XYXY - spatial_size = (100, 100) + canvas_size = (100, 100) in_boxes = [ [10.0, 10.0, 20.0, 20.0], [5.0, 10.0, 15.0, 20.0], @@ -923,19 +747,19 @@ def 
_compute_expected_bbox(bbox, top_, left_, height_, width_, size_): expected_bboxes.append(_compute_expected_bbox(list(in_box), top, left, height, width, size)) expected_bboxes = torch.tensor(expected_bboxes, device=device) - in_boxes = datapoints.BoundingBox( - in_boxes, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size, device=device + in_boxes = datapoints.BoundingBoxes( + in_boxes, format=datapoints.BoundingBoxFormat.XYXY, canvas_size=canvas_size, device=device ) if format != datapoints.BoundingBoxFormat.XYXY: - in_boxes = convert_format_bounding_box(in_boxes, datapoints.BoundingBoxFormat.XYXY, format) + in_boxes = convert_format_bounding_boxes(in_boxes, datapoints.BoundingBoxFormat.XYXY, format) - output_boxes, output_spatial_size = F.resized_crop_bounding_box(in_boxes, format, top, left, height, width, size) + output_boxes, output_canvas_size = F.resized_crop_bounding_boxes(in_boxes, format, top, left, height, width, size) if format != datapoints.BoundingBoxFormat.XYXY: - output_boxes = convert_format_bounding_box(output_boxes, format, datapoints.BoundingBoxFormat.XYXY) + output_boxes = convert_format_bounding_boxes(output_boxes, format, datapoints.BoundingBoxFormat.XYXY) torch.testing.assert_close(output_boxes, expected_bboxes) - torch.testing.assert_close(output_spatial_size, size) + torch.testing.assert_close(output_canvas_size, size) def _parse_padding(padding): @@ -952,7 +776,7 @@ def _parse_padding(padding): @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("padding", [[1], [1, 1], [1, 1, 2, 2]]) -def test_correctness_pad_bounding_box(device, padding): +def test_correctness_pad_bounding_boxes(device, padding): def _compute_expected_bbox(bbox, padding_): pad_left, pad_up, _, _ = _parse_padding(padding_) @@ -961,41 +785,41 @@ def _compute_expected_bbox(bbox, padding_): bbox = ( bbox.clone() if format == datapoints.BoundingBoxFormat.XYXY - else convert_format_bounding_box(bbox, new_format=datapoints.BoundingBoxFormat.XYXY) + else convert_format_bounding_boxes(bbox, new_format=datapoints.BoundingBoxFormat.XYXY) ) bbox[0::2] += pad_left bbox[1::2] += pad_up - bbox = convert_format_bounding_box(bbox, new_format=format) + bbox = convert_format_bounding_boxes(bbox, new_format=format) if bbox.dtype != dtype: # Temporary cast to original dtype # e.g. 
float32 -> int bbox = bbox.to(dtype) return bbox - def _compute_expected_spatial_size(bbox, padding_): + def _compute_expected_canvas_size(bbox, padding_): pad_left, pad_up, pad_right, pad_down = _parse_padding(padding_) - height, width = bbox.spatial_size + height, width = bbox.canvas_size return height + pad_up + pad_down, width + pad_left + pad_right for bboxes in make_bounding_boxes(): bboxes = bboxes.to(device) bboxes_format = bboxes.format - bboxes_spatial_size = bboxes.spatial_size + bboxes_canvas_size = bboxes.canvas_size - output_boxes, output_spatial_size = F.pad_bounding_box( - bboxes, format=bboxes_format, spatial_size=bboxes_spatial_size, padding=padding + output_boxes, output_canvas_size = F.pad_bounding_boxes( + bboxes, format=bboxes_format, canvas_size=bboxes_canvas_size, padding=padding ) - torch.testing.assert_close(output_spatial_size, _compute_expected_spatial_size(bboxes, padding)) + torch.testing.assert_close(output_canvas_size, _compute_expected_canvas_size(bboxes, padding)) if bboxes.ndim < 2 or bboxes.shape[0] == 0: bboxes = [bboxes] expected_bboxes = [] for bbox in bboxes: - bbox = datapoints.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size) + bbox = datapoints.BoundingBoxes(bbox, format=bboxes_format, canvas_size=bboxes_canvas_size) expected_bboxes.append(_compute_expected_bbox(bbox, padding)) if len(expected_bboxes) > 1: @@ -1025,7 +849,7 @@ def test_correctness_pad_segmentation_mask_on_fixed_input(device): [[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]], ], ) -def test_correctness_perspective_bounding_box(device, startpoints, endpoints): +def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints): def _compute_expected_bbox(bbox, pcoeffs_): m1 = np.array( [ @@ -1040,7 +864,7 @@ def _compute_expected_bbox(bbox, pcoeffs_): ] ) - bbox_xyxy = convert_format_bounding_box(bbox, new_format=datapoints.BoundingBoxFormat.XYXY) + bbox_xyxy = convert_format_bounding_boxes(bbox, new_format=datapoints.BoundingBoxFormat.XYXY) points = np.array( [ [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0], @@ -1060,27 +884,27 @@ def _compute_expected_bbox(bbox, pcoeffs_): np.max(transformed_points[:, 1]), ] ) - out_bbox = datapoints.BoundingBox( + out_bbox = datapoints.BoundingBoxes( out_bbox, format=datapoints.BoundingBoxFormat.XYXY, - spatial_size=bbox.spatial_size, + canvas_size=bbox.canvas_size, dtype=bbox.dtype, device=bbox.device, ) - return clamp_bounding_box(convert_format_bounding_box(out_bbox, new_format=bbox.format)) + return clamp_bounding_boxes(convert_format_bounding_boxes(out_bbox, new_format=bbox.format)) - spatial_size = (32, 38) + canvas_size = (32, 38) pcoeffs = _get_perspective_coeffs(startpoints, endpoints) inv_pcoeffs = _get_perspective_coeffs(endpoints, startpoints) - for bboxes in make_bounding_boxes(spatial_size=spatial_size, extra_dims=((4,),)): + for bboxes in make_bounding_boxes(canvas_size=canvas_size, extra_dims=((4,),)): bboxes = bboxes.to(device) - output_bboxes = F.perspective_bounding_box( + output_bboxes = F.perspective_bounding_boxes( bboxes.as_subclass(torch.Tensor), format=bboxes.format, - spatial_size=bboxes.spatial_size, + canvas_size=bboxes.canvas_size, startpoints=None, endpoints=None, coefficients=pcoeffs, @@ -1091,7 +915,7 @@ def _compute_expected_bbox(bbox, pcoeffs_): expected_bboxes = [] for bbox in bboxes: - bbox = datapoints.BoundingBox(bbox, format=bboxes.format, spatial_size=bboxes.spatial_size) + bbox = datapoints.BoundingBoxes(bbox, format=bboxes.format, 
canvas_size=bboxes.canvas_size) expected_bboxes.append(_compute_expected_bbox(bbox, inv_pcoeffs)) if len(expected_bboxes) > 1: expected_bboxes = torch.stack(expected_bboxes) @@ -1105,18 +929,18 @@ def _compute_expected_bbox(bbox, pcoeffs_): "output_size", [(18, 18), [18, 15], (16, 19), [12], [46, 48]], ) -def test_correctness_center_crop_bounding_box(device, output_size): +def test_correctness_center_crop_bounding_boxes(device, output_size): def _compute_expected_bbox(bbox, output_size_): format_ = bbox.format - spatial_size_ = bbox.spatial_size + canvas_size_ = bbox.canvas_size dtype = bbox.dtype - bbox = convert_format_bounding_box(bbox.float(), format_, datapoints.BoundingBoxFormat.XYWH) + bbox = convert_format_bounding_boxes(bbox.float(), format_, datapoints.BoundingBoxFormat.XYWH) if len(output_size_) == 1: output_size_.append(output_size_[-1]) - cy = int(round((spatial_size_[0] - output_size_[0]) * 0.5)) - cx = int(round((spatial_size_[1] - output_size_[1]) * 0.5)) + cy = int(round((canvas_size_[0] - output_size_[0]) * 0.5)) + cx = int(round((canvas_size_[1] - output_size_[1]) * 0.5)) out_bbox = [ bbox[0].item() - cx, bbox[1].item() - cy, @@ -1124,17 +948,17 @@ def _compute_expected_bbox(bbox, output_size_): bbox[3].item(), ] out_bbox = torch.tensor(out_bbox) - out_bbox = convert_format_bounding_box(out_bbox, datapoints.BoundingBoxFormat.XYWH, format_) - out_bbox = clamp_bounding_box(out_bbox, format=format_, spatial_size=output_size) + out_bbox = convert_format_bounding_boxes(out_bbox, datapoints.BoundingBoxFormat.XYWH, format_) + out_bbox = clamp_bounding_boxes(out_bbox, format=format_, canvas_size=output_size) return out_bbox.to(dtype=dtype, device=bbox.device) for bboxes in make_bounding_boxes(extra_dims=((4,),)): bboxes = bboxes.to(device) bboxes_format = bboxes.format - bboxes_spatial_size = bboxes.spatial_size + bboxes_canvas_size = bboxes.canvas_size - output_boxes, output_spatial_size = F.center_crop_bounding_box( - bboxes, bboxes_format, bboxes_spatial_size, output_size + output_boxes, output_canvas_size = F.center_crop_bounding_boxes( + bboxes, bboxes_format, bboxes_canvas_size, output_size ) if bboxes.ndim < 2: @@ -1142,7 +966,7 @@ def _compute_expected_bbox(bbox, output_size_): expected_bboxes = [] for bbox in bboxes: - bbox = datapoints.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size) + bbox = datapoints.BoundingBoxes(bbox, format=bboxes_format, canvas_size=bboxes_canvas_size) expected_bboxes.append(_compute_expected_bbox(bbox, output_size)) if len(expected_bboxes) > 1: @@ -1151,7 +975,7 @@ def _compute_expected_bbox(bbox, output_size_): expected_bboxes = expected_bboxes[0] torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0) - torch.testing.assert_close(output_spatial_size, output_size) + torch.testing.assert_close(output_canvas_size, output_size) @pytest.mark.parametrize("device", cpu_and_cuda()) @@ -1179,11 +1003,11 @@ def _compute_expected_mask(mask, output_size): # Copied from test/test_functional_tensor.py @pytest.mark.parametrize("device", cpu_and_cuda()) -@pytest.mark.parametrize("spatial_size", ("small", "large")) +@pytest.mark.parametrize("canvas_size", ("small", "large")) @pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16]) @pytest.mark.parametrize("ksize", [(3, 3), [3, 5], (23, 23)]) @pytest.mark.parametrize("sigma", [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8), (1.7, 1.7)]) -def test_correctness_gaussian_blur_image_tensor(device, spatial_size, dt, ksize, sigma): +def 
test_correctness_gaussian_blur_image_tensor(device, canvas_size, dt, ksize, sigma): fn = F.gaussian_blur_image_tensor # true_cv2_results = { @@ -1203,7 +1027,7 @@ def test_correctness_gaussian_blur_image_tensor(device, spatial_size, dt, ksize, p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "gaussian_blur_opencv_results.pt") true_cv2_results = torch.load(p) - if spatial_size == "small": + if canvas_size == "small": tensor = ( torch.from_numpy(np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))).permute(2, 0, 1).to(device) ) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 2130a8cf50a..e9b72161e67 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -1,4 +1,5 @@ import contextlib +import decimal import inspect import math import re @@ -16,15 +17,23 @@ assert_no_warnings, cache, cpu_and_cuda, + freeze_rng_state, ignore_jit_no_profile_information_warning, make_bounding_box, make_detection_mask, make_image, + make_image_pil, + make_image_tensor, make_segmentation_mask, make_video, + needs_cuda, set_rng_seed, ) + +from torch import nn from torch.testing import assert_close +from torch.utils._pytree import tree_map +from torch.utils.data import DataLoader, default_collate from torchvision import datapoints from torchvision.transforms._functional_tensor import _max_value as get_max_value @@ -55,18 +64,21 @@ def _check_kernel_cuda_vs_cpu(kernel, input, *args, rtol, atol, **kwargs): input_cuda = input.as_subclass(torch.Tensor) input_cpu = input_cuda.to("cpu") - actual = kernel(input_cuda, *args, **kwargs) - expected = kernel(input_cpu, *args, **kwargs) + with freeze_rng_state(): + actual = kernel(input_cuda, *args, **kwargs) + with freeze_rng_state(): + expected = kernel(input_cpu, *args, **kwargs) assert_close(actual, expected, check_device=False, rtol=rtol, atol=atol) @cache -def _script(fn): +def _script(obj): try: - return torch.jit.script(fn) + return torch.jit.script(obj) except Exception as error: - raise AssertionError(f"Trying to `torch.jit.script` '{fn.__name__}' raised the error above.") from error + name = getattr(obj, "__name__", obj.__class__.__name__) + raise AssertionError(f"Trying to `torch.jit.script` '{name}' raised the error above.") from error def _check_kernel_scripted_vs_eager(kernel, input, *args, rtol, atol, **kwargs): @@ -123,6 +135,7 @@ def check_kernel( check_cuda_vs_cpu=True, check_scripted_vs_eager=True, check_batched_vs_unbatched=True, + expect_same_dtype=True, **kwargs, ): initial_input_version = input._version @@ -135,7 +148,8 @@ def check_kernel( # check that no inplace operation happened assert input._version == initial_input_version - assert output.dtype == input.dtype + if expect_same_dtype: + assert output.dtype == input.dtype assert output.device == input.device if check_cuda_vs_cpu: @@ -182,7 +196,7 @@ def _check_dispatcher_dispatch(dispatcher, kernel, input, *args, **kwargs): assert isinstance(output, type(input)) - if isinstance(input, datapoints.BoundingBox): + if isinstance(input, datapoints.BoundingBoxes): assert output.format == input.format @@ -272,7 +286,7 @@ def check_dispatcher_signatures_match(dispatcher, *, kernel, input_type): def _check_transform_v1_compatibility(transform, input): """If the transform defines the ``_v1_transform_cls`` attribute, checks if the transform has a public, static ``get_params`` method, is scriptable, and the scripted version can be called without error.""" - if not hasattr(transform, 
"_v1_transform_cls"): + if transform._v1_transform_cls is None: return if type(input) is not torch.Tensor: @@ -292,7 +306,7 @@ def check_transform(transform_cls, input, *args, **kwargs): output = transform(input) assert isinstance(output, type(input)) - if isinstance(input, datapoints.BoundingBox): + if isinstance(input, datapoints.BoundingBoxes): assert output.format == input.format _check_transform_v1_compatibility(transform, input) @@ -308,42 +322,6 @@ def wrapper(input, *args, **kwargs): return wrapper -def make_input(input_type, *, dtype=None, device="cpu", spatial_size=(17, 11), mask_type="segmentation", **kwargs): - if input_type in {torch.Tensor, PIL.Image.Image, datapoints.Image}: - input = make_image(size=spatial_size, dtype=dtype or torch.uint8, device=device, **kwargs) - if input_type is torch.Tensor: - input = input.as_subclass(torch.Tensor) - elif input_type is PIL.Image.Image: - input = F.to_image_pil(input) - elif input_type is datapoints.BoundingBox: - kwargs.setdefault("format", datapoints.BoundingBoxFormat.XYXY) - input = make_bounding_box( - dtype=dtype or torch.float32, - device=device, - spatial_size=spatial_size, - **kwargs, - ) - elif input_type is datapoints.Mask: - if mask_type == "segmentation": - make_mask = make_segmentation_mask - default_dtype = torch.uint8 - elif mask_type == "detection": - make_mask = make_detection_mask - default_dtype = torch.bool - else: - raise ValueError(f"`mask_type` can be `'segmentation'` or `'detection'`, but got {mask_type}.") - input = make_mask(size=spatial_size, dtype=dtype or default_dtype, device=device, **kwargs) - elif input_type is datapoints.Video: - input = make_video(size=spatial_size, dtype=dtype or torch.uint8, device=device, **kwargs) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {input_type} instead." 
- ) - - return input - - def param_value_parametrization(**kwargs): """Helper function to turn @@ -414,13 +392,13 @@ def assert_warns_antialias_default_value(): yield -def reference_affine_bounding_box_helper(bounding_box, *, format, spatial_size, affine_matrix): +def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_size, affine_matrix): def transform(bbox): # Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1 in_dtype = bbox.dtype if not torch.is_floating_point(bbox): bbox = bbox.float() - bbox_xyxy = F.convert_format_bounding_box( + bbox_xyxy = F.convert_format_bounding_boxes( bbox.as_subclass(torch.Tensor), old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, @@ -444,15 +422,15 @@ def transform(bbox): ], dtype=bbox_xyxy.dtype, ) - out_bbox = F.convert_format_bounding_box( + out_bbox = F.convert_format_bounding_boxes( out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True ) # It is important to clamp before casting, especially for CXCYWH format, dtype=int64 - out_bbox = F.clamp_bounding_box(out_bbox, format=format, spatial_size=spatial_size) + out_bbox = F.clamp_bounding_boxes(out_bbox, format=format, canvas_size=canvas_size) out_bbox = out_bbox.to(dtype=in_dtype) return out_bbox - return torch.stack([transform(b) for b in bounding_box.reshape(-1, 4).unbind()]).reshape(bounding_box.shape) + return torch.stack([transform(b) for b in bounding_boxes.reshape(-1, 4).unbind()]).reshape(bounding_boxes.shape) class TestResize: @@ -516,7 +494,7 @@ def test_kernel_image_tensor(self, size, interpolation, use_max_size, antialias, check_kernel( F.resize_image_tensor, - make_input(datapoints.Image, dtype=dtype, device=device, spatial_size=self.INPUT_SIZE), + make_image(self.INPUT_SIZE, dtype=dtype, device=device), size=size, interpolation=interpolation, **max_size_kwarg, @@ -530,69 +508,63 @@ def test_kernel_image_tensor(self, size, interpolation, use_max_size, antialias, @pytest.mark.parametrize("use_max_size", [True, False]) @pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_kernel_bounding_box(self, format, size, use_max_size, dtype, device): + def test_kernel_bounding_boxes(self, format, size, use_max_size, dtype, device): if not (max_size_kwarg := self._make_max_size_kwarg(use_max_size=use_max_size, size=size)): return - bounding_box = make_input( - datapoints.BoundingBox, dtype=dtype, device=device, format=format, spatial_size=self.INPUT_SIZE + bounding_boxes = make_bounding_box( + format=format, + canvas_size=self.INPUT_SIZE, + dtype=dtype, + device=device, ) check_kernel( - F.resize_bounding_box, - bounding_box, - spatial_size=bounding_box.spatial_size, + F.resize_bounding_boxes, + bounding_boxes, + canvas_size=bounding_boxes.canvas_size, size=size, **max_size_kwarg, check_scripted_vs_eager=not isinstance(size, int), ) - @pytest.mark.parametrize("mask_type", ["segmentation", "detection"]) - def test_kernel_mask(self, mask_type): - check_kernel( - F.resize_mask, - make_input(datapoints.Mask, spatial_size=self.INPUT_SIZE, mask_type=mask_type), - size=self.OUTPUT_SIZES[-1], - ) + @pytest.mark.parametrize("make_mask", [make_segmentation_mask, make_detection_mask]) + def test_kernel_mask(self, make_mask): + check_kernel(F.resize_mask, make_mask(self.INPUT_SIZE), size=self.OUTPUT_SIZES[-1]) def test_kernel_video(self): - check_kernel( - F.resize_video, - make_input(datapoints.Video, 
spatial_size=self.INPUT_SIZE), - size=self.OUTPUT_SIZES[-1], - antialias=True, - ) + check_kernel(F.resize_video, make_video(self.INPUT_SIZE), size=self.OUTPUT_SIZES[-1], antialias=True) @pytest.mark.parametrize("size", OUTPUT_SIZES) @pytest.mark.parametrize( - ("input_type", "kernel"), + ("kernel", "make_input"), [ - (torch.Tensor, F.resize_image_tensor), - (PIL.Image.Image, F.resize_image_pil), - (datapoints.Image, F.resize_image_tensor), - (datapoints.BoundingBox, F.resize_bounding_box), - (datapoints.Mask, F.resize_mask), - (datapoints.Video, F.resize_video), + (F.resize_image_tensor, make_image_tensor), + (F.resize_image_pil, make_image_pil), + (F.resize_image_tensor, make_image), + (F.resize_bounding_boxes, make_bounding_box), + (F.resize_mask, make_segmentation_mask), + (F.resize_video, make_video), ], ) - def test_dispatcher(self, size, input_type, kernel): + def test_dispatcher(self, size, kernel, make_input): check_dispatcher( F.resize, kernel, - make_input(input_type, spatial_size=self.INPUT_SIZE), + make_input(self.INPUT_SIZE), size=size, antialias=True, check_scripted_smoke=not isinstance(size, int), ) @pytest.mark.parametrize( - ("input_type", "kernel"), + ("kernel", "input_type"), [ - (torch.Tensor, F.resize_image_tensor), - (PIL.Image.Image, F.resize_image_pil), - (datapoints.Image, F.resize_image_tensor), - (datapoints.BoundingBox, F.resize_bounding_box), - (datapoints.Mask, F.resize_mask), - (datapoints.Video, F.resize_video), + (F.resize_image_tensor, torch.Tensor), + (F.resize_image_pil, PIL.Image.Image), + (F.resize_image_tensor, datapoints.Image), + (F.resize_bounding_boxes, datapoints.BoundingBoxes), + (F.resize_mask, datapoints.Mask), + (F.resize_video, datapoints.Video), ], ) def test_dispatcher_signature(self, kernel, input_type): @@ -601,22 +573,23 @@ def test_dispatcher_signature(self, kernel, input_type): @pytest.mark.parametrize("size", OUTPUT_SIZES) @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize( - "input_type", - [torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.BoundingBox, datapoints.Mask, datapoints.Video], + "make_input", + [ + make_image_tensor, + make_image_pil, + make_image, + make_bounding_box, + make_segmentation_mask, + make_detection_mask, + make_video, + ], ) - def test_transform(self, size, device, input_type): - input = make_input(input_type, device=device, spatial_size=self.INPUT_SIZE) - - check_transform( - transforms.Resize, - input, - size=size, - antialias=True, - ) + def test_transform(self, size, device, make_input): + check_transform(transforms.Resize, make_input(self.INPUT_SIZE, device=device), size=size, antialias=True) def _check_output_size(self, input, output, *, size, max_size): - assert tuple(F.get_spatial_size(output)) == self._compute_output_size( - input_size=F.get_spatial_size(input), size=size, max_size=max_size + assert tuple(F.get_size(output)) == self._compute_output_size( + input_size=F.get_size(input), size=size, max_size=max_size ) @pytest.mark.parametrize("size", OUTPUT_SIZES) @@ -629,7 +602,7 @@ def test_image_correctness(self, size, interpolation, use_max_size, fn): if not (max_size_kwarg := self._make_max_size_kwarg(use_max_size=use_max_size, size=size)): return - image = make_input(torch.Tensor, dtype=torch.uint8, device="cpu", spatial_size=self.INPUT_SIZE) + image = make_image(self.INPUT_SIZE, dtype=torch.uint8) actual = fn(image, size=size, interpolation=interpolation, **max_size_kwarg, antialias=True) expected = F.to_image_tensor( @@ -639,54 +612,54 @@ def 
test_image_correctness(self, size, interpolation, use_max_size, fn): self._check_output_size(image, actual, size=size, **max_size_kwarg) torch.testing.assert_close(actual, expected, atol=1, rtol=0) - def _reference_resize_bounding_box(self, bounding_box, *, size, max_size=None): - old_height, old_width = bounding_box.spatial_size + def _reference_resize_bounding_boxes(self, bounding_boxes, *, size, max_size=None): + old_height, old_width = bounding_boxes.canvas_size new_height, new_width = self._compute_output_size( - input_size=bounding_box.spatial_size, size=size, max_size=max_size + input_size=bounding_boxes.canvas_size, size=size, max_size=max_size ) if (old_height, old_width) == (new_height, new_width): - return bounding_box + return bounding_boxes affine_matrix = np.array( [ [new_width / old_width, 0, 0], [0, new_height / old_height, 0], ], - dtype="float64" if bounding_box.dtype == torch.float64 else "float32", + dtype="float64" if bounding_boxes.dtype == torch.float64 else "float32", ) - expected_bboxes = reference_affine_bounding_box_helper( - bounding_box, - format=bounding_box.format, - spatial_size=(new_height, new_width), + expected_bboxes = reference_affine_bounding_boxes_helper( + bounding_boxes, + format=bounding_boxes.format, + canvas_size=(new_height, new_width), affine_matrix=affine_matrix, ) - return datapoints.BoundingBox.wrap_like(bounding_box, expected_bboxes, spatial_size=(new_height, new_width)) + return datapoints.BoundingBoxes.wrap_like(bounding_boxes, expected_bboxes, canvas_size=(new_height, new_width)) @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat)) @pytest.mark.parametrize("size", OUTPUT_SIZES) @pytest.mark.parametrize("use_max_size", [True, False]) @pytest.mark.parametrize("fn", [F.resize, transform_cls_to_functional(transforms.Resize)]) - def test_bounding_box_correctness(self, format, size, use_max_size, fn): + def test_bounding_boxes_correctness(self, format, size, use_max_size, fn): if not (max_size_kwarg := self._make_max_size_kwarg(use_max_size=use_max_size, size=size)): return - bounding_box = make_input(datapoints.BoundingBox, spatial_size=self.INPUT_SIZE) + bounding_boxes = make_bounding_box(format=format, canvas_size=self.INPUT_SIZE) - actual = fn(bounding_box, size=size, **max_size_kwarg) - expected = self._reference_resize_bounding_box(bounding_box, size=size, **max_size_kwarg) + actual = fn(bounding_boxes, size=size, **max_size_kwarg) + expected = self._reference_resize_bounding_boxes(bounding_boxes, size=size, **max_size_kwarg) - self._check_output_size(bounding_box, actual, size=size, **max_size_kwarg) + self._check_output_size(bounding_boxes, actual, size=size, **max_size_kwarg) torch.testing.assert_close(actual, expected) @pytest.mark.parametrize("interpolation", set(transforms.InterpolationMode) - set(INTERPOLATION_MODES)) @pytest.mark.parametrize( - "input_type", - [torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.Video], + "make_input", + [make_image_tensor, make_image_pil, make_image, make_video], ) - def test_pil_interpolation_compat_smoke(self, interpolation, input_type): - input = make_input(input_type, spatial_size=self.INPUT_SIZE) + def test_pil_interpolation_compat_smoke(self, interpolation, make_input): + input = make_input(self.INPUT_SIZE) with ( contextlib.nullcontext() @@ -702,16 +675,22 @@ def test_pil_interpolation_compat_smoke(self, interpolation, input_type): def test_dispatcher_pil_antialias_warning(self): with pytest.warns(UserWarning, match="Anti-alias option is always applied for PIL 
Image input"): - F.resize( - make_input(PIL.Image.Image, spatial_size=self.INPUT_SIZE), size=self.OUTPUT_SIZES[0], antialias=False - ) + F.resize(make_image_pil(self.INPUT_SIZE), size=self.OUTPUT_SIZES[0], antialias=False) @pytest.mark.parametrize("size", OUTPUT_SIZES) @pytest.mark.parametrize( - "input_type", - [torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.BoundingBox, datapoints.Mask, datapoints.Video], + "make_input", + [ + make_image_tensor, + make_image_pil, + make_image, + make_bounding_box, + make_segmentation_mask, + make_detection_mask, + make_video, + ], ) - def test_max_size_error(self, size, input_type): + def test_max_size_error(self, size, make_input): if isinstance(size, int) or len(size) == 1: max_size = (size if isinstance(size, int) else size[0]) - 1 match = "must be strictly greater than the requested size" @@ -721,39 +700,39 @@ def test_max_size_error(self, size, input_type): match = "size should be an int or a sequence of length 1" with pytest.raises(ValueError, match=match): - F.resize(make_input(input_type, spatial_size=self.INPUT_SIZE), size=size, max_size=max_size, antialias=True) + F.resize(make_input(self.INPUT_SIZE), size=size, max_size=max_size, antialias=True) @pytest.mark.parametrize("interpolation", INTERPOLATION_MODES) @pytest.mark.parametrize( - "input_type", - [torch.Tensor, datapoints.Image, datapoints.Video], + "make_input", + [make_image_tensor, make_image, make_video], ) - def test_antialias_warning(self, interpolation, input_type): + def test_antialias_warning(self, interpolation, make_input): with ( assert_warns_antialias_default_value() if interpolation in {transforms.InterpolationMode.BILINEAR, transforms.InterpolationMode.BICUBIC} else assert_no_warnings() ): F.resize( - make_input(input_type, spatial_size=self.INPUT_SIZE), + make_input(self.INPUT_SIZE), size=self.OUTPUT_SIZES[0], interpolation=interpolation, ) @pytest.mark.parametrize("interpolation", INTERPOLATION_MODES) @pytest.mark.parametrize( - "input_type", - [torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.Video], + "make_input", + [make_image_tensor, make_image_pil, make_image, make_video], ) - def test_interpolation_int(self, interpolation, input_type): + def test_interpolation_int(self, interpolation, make_input): + input = make_input(self.INPUT_SIZE) + # `InterpolationMode.NEAREST_EXACT` has no proper corresponding integer equivalent. Internally, we map it to # `0` to be the same as `InterpolationMode.NEAREST` for PIL. However, for the tensor backend there is a # difference and thus we don't test it here. 
- if issubclass(input_type, torch.Tensor) and interpolation is transforms.InterpolationMode.NEAREST_EXACT: + if isinstance(input, torch.Tensor) and interpolation is transforms.InterpolationMode.NEAREST_EXACT: return - input = make_input(input_type, spatial_size=self.INPUT_SIZE) - expected = F.resize(input, size=self.OUTPUT_SIZES[0], interpolation=interpolation, antialias=True) actual = F.resize( input, size=self.OUTPUT_SIZES[0], interpolation=pil_modes_mapping[interpolation], antialias=True @@ -769,13 +748,21 @@ def test_transform_unknown_size_error(self): "size", [min(INPUT_SIZE), [min(INPUT_SIZE)], (min(INPUT_SIZE),), list(INPUT_SIZE), tuple(INPUT_SIZE)] ) @pytest.mark.parametrize( - "input_type", - [torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.BoundingBox, datapoints.Mask, datapoints.Video], + "make_input", + [ + make_image_tensor, + make_image_pil, + make_image, + make_bounding_box, + make_segmentation_mask, + make_detection_mask, + make_video, + ], ) - def test_noop(self, size, input_type): - input = make_input(input_type, spatial_size=self.INPUT_SIZE) + def test_noop(self, size, make_input): + input = make_input(self.INPUT_SIZE) - output = F.resize(input, size=size, antialias=True) + output = F.resize(input, size=F.get_size(input), antialias=True) # This identity check is not a requirement. It is here to avoid breaking the behavior by accident. If there # is a good reason to break this, feel free to downgrade to an equality check. @@ -788,133 +775,139 @@ def test_noop(self, size, input_type): assert output is input @pytest.mark.parametrize( - "input_type", - [torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.BoundingBox, datapoints.Mask, datapoints.Video], + "make_input", + [ + make_image_tensor, + make_image_pil, + make_image, + make_bounding_box, + make_segmentation_mask, + make_detection_mask, + make_video, + ], ) - def test_no_regression_5405(self, input_type): + def test_no_regression_5405(self, make_input): # Checks that `max_size` is not ignored if `size == small_edge_size` # See https://github.com/pytorch/vision/issues/5405 - input = make_input(input_type, spatial_size=self.INPUT_SIZE) + input = make_input(self.INPUT_SIZE) - size = min(F.get_spatial_size(input)) + size = min(F.get_size(input)) max_size = size + 1 output = F.resize(input, size=size, max_size=max_size, antialias=True) - assert max(F.get_spatial_size(output)) == max_size + assert max(F.get_size(output)) == max_size class TestHorizontalFlip: @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8]) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_kernel_image_tensor(self, dtype, device): - check_kernel(F.horizontal_flip_image_tensor, make_input(torch.Tensor, dtype=dtype, device=device)) + check_kernel(F.horizontal_flip_image_tensor, make_image(dtype=dtype, device=device)) @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat)) @pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_kernel_bounding_box(self, format, dtype, device): - bounding_box = make_input(datapoints.BoundingBox, dtype=dtype, device=device, format=format) + def test_kernel_bounding_boxes(self, format, dtype, device): + bounding_boxes = make_bounding_box(format=format, dtype=dtype, device=device) check_kernel( - F.horizontal_flip_bounding_box, - bounding_box, + F.horizontal_flip_bounding_boxes, + bounding_boxes, format=format, - spatial_size=bounding_box.spatial_size, + canvas_size=bounding_boxes.canvas_size, ) - 
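# A minimal usage sketch of the renamed API exercised above (v2 beta transforms; the numbers
# are illustrative only): on a canvas of width 10, a horizontal flip maps x to width - x,
# so an XYXY box (2, 1, 5, 3) becomes (5, 1, 8, 3).
#
#   boxes = datapoints.BoundingBoxes([[2.0, 1.0, 5.0, 3.0]], format="XYXY", canvas_size=(10, 10))
#   flipped = F.horizontal_flip(boxes)
#   # `flipped` keeps the same format and canvas_size and contains [[5.0, 1.0, 8.0, 3.0]]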
@pytest.mark.parametrize("mask_type", ["segmentation", "detection"]) - def test_kernel_mask(self, mask_type): - check_kernel(F.horizontal_flip_mask, make_input(datapoints.Mask, mask_type=mask_type)) + @pytest.mark.parametrize("make_mask", [make_segmentation_mask, make_detection_mask]) + def test_kernel_mask(self, make_mask): + check_kernel(F.horizontal_flip_mask, make_mask()) def test_kernel_video(self): - check_kernel(F.horizontal_flip_video, make_input(datapoints.Video)) + check_kernel(F.horizontal_flip_video, make_video()) @pytest.mark.parametrize( - ("input_type", "kernel"), + ("kernel", "make_input"), [ - (torch.Tensor, F.horizontal_flip_image_tensor), - (PIL.Image.Image, F.horizontal_flip_image_pil), - (datapoints.Image, F.horizontal_flip_image_tensor), - (datapoints.BoundingBox, F.horizontal_flip_bounding_box), - (datapoints.Mask, F.horizontal_flip_mask), - (datapoints.Video, F.horizontal_flip_video), + (F.horizontal_flip_image_tensor, make_image_tensor), + (F.horizontal_flip_image_pil, make_image_pil), + (F.horizontal_flip_image_tensor, make_image), + (F.horizontal_flip_bounding_boxes, make_bounding_box), + (F.horizontal_flip_mask, make_segmentation_mask), + (F.horizontal_flip_video, make_video), ], ) - def test_dispatcher(self, kernel, input_type): - check_dispatcher(F.horizontal_flip, kernel, make_input(input_type)) + def test_dispatcher(self, kernel, make_input): + check_dispatcher(F.horizontal_flip, kernel, make_input()) @pytest.mark.parametrize( - ("input_type", "kernel"), + ("kernel", "input_type"), [ - (torch.Tensor, F.horizontal_flip_image_tensor), - (PIL.Image.Image, F.horizontal_flip_image_pil), - (datapoints.Image, F.horizontal_flip_image_tensor), - (datapoints.BoundingBox, F.horizontal_flip_bounding_box), - (datapoints.Mask, F.horizontal_flip_mask), - (datapoints.Video, F.horizontal_flip_video), + (F.horizontal_flip_image_tensor, torch.Tensor), + (F.horizontal_flip_image_pil, PIL.Image.Image), + (F.horizontal_flip_image_tensor, datapoints.Image), + (F.horizontal_flip_bounding_boxes, datapoints.BoundingBoxes), + (F.horizontal_flip_mask, datapoints.Mask), + (F.horizontal_flip_video, datapoints.Video), ], ) def test_dispatcher_signature(self, kernel, input_type): check_dispatcher_signatures_match(F.horizontal_flip, kernel=kernel, input_type=input_type) @pytest.mark.parametrize( - "input_type", - [torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.BoundingBox, datapoints.Mask, datapoints.Video], + "make_input", + [make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video], ) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_transform(self, input_type, device): - input = make_input(input_type, device=device) - - check_transform(transforms.RandomHorizontalFlip, input, p=1) + def test_transform(self, make_input, device): + check_transform(transforms.RandomHorizontalFlip, make_input(device=device), p=1) @pytest.mark.parametrize( "fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)] ) def test_image_correctness(self, fn): - image = make_input(torch.Tensor, dtype=torch.uint8, device="cpu") + image = make_image(dtype=torch.uint8, device="cpu") actual = fn(image) expected = F.to_image_tensor(F.horizontal_flip(F.to_image_pil(image))) torch.testing.assert_close(actual, expected) - def _reference_horizontal_flip_bounding_box(self, bounding_box): + def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes): affine_matrix = np.array( [ - [-1, 0, bounding_box.spatial_size[1]], 
+ [-1, 0, bounding_boxes.canvas_size[1]], [0, 1, 0], ], - dtype="float64" if bounding_box.dtype == torch.float64 else "float32", + dtype="float64" if bounding_boxes.dtype == torch.float64 else "float32", ) - expected_bboxes = reference_affine_bounding_box_helper( - bounding_box, - format=bounding_box.format, - spatial_size=bounding_box.spatial_size, + expected_bboxes = reference_affine_bounding_boxes_helper( + bounding_boxes, + format=bounding_boxes.format, + canvas_size=bounding_boxes.canvas_size, affine_matrix=affine_matrix, ) - return datapoints.BoundingBox.wrap_like(bounding_box, expected_bboxes) + return datapoints.BoundingBoxes.wrap_like(bounding_boxes, expected_bboxes) @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat)) @pytest.mark.parametrize( "fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)] ) - def test_bounding_box_correctness(self, format, fn): - bounding_box = make_input(datapoints.BoundingBox, format=format) + def test_bounding_boxes_correctness(self, format, fn): + bounding_boxes = make_bounding_box(format=format) - actual = fn(bounding_box) - expected = self._reference_horizontal_flip_bounding_box(bounding_box) + actual = fn(bounding_boxes) + expected = self._reference_horizontal_flip_bounding_boxes(bounding_boxes) torch.testing.assert_close(actual, expected) @pytest.mark.parametrize( - "input_type", - [torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.BoundingBox, datapoints.Mask, datapoints.Video], + "make_input", + [make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video], ) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_transform_noop(self, input_type, device): - input = make_input(input_type, device=device) + def test_transform_noop(self, make_input, device): + input = make_input(device=device) transform = transforms.RandomHorizontalFlip(p=0) @@ -979,7 +972,7 @@ def test_kernel_image_tensor(self, param, value, dtype, device): value = adapt_fill(value, dtype=dtype) self._check_kernel( F.affine_image_tensor, - make_input(torch.Tensor, dtype=dtype, device=device), + make_image(dtype=dtype, device=device), **{param: value}, check_scripted_vs_eager=not (param in {"shear", "fill"} and isinstance(value, (int, float))), check_cuda_vs_cpu=dict(atol=1, rtol=0) @@ -996,59 +989,59 @@ def test_kernel_image_tensor(self, param, value, dtype, device): @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat)) @pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_kernel_bounding_box(self, param, value, format, dtype, device): - bounding_box = make_input(datapoints.BoundingBox, format=format, dtype=dtype, device=device) + def test_kernel_bounding_boxes(self, param, value, format, dtype, device): + bounding_boxes = make_bounding_box(format=format, dtype=dtype, device=device) self._check_kernel( - F.affine_bounding_box, - make_input(datapoints.BoundingBox, format=format, dtype=dtype, device=device), + F.affine_bounding_boxes, + bounding_boxes, format=format, - spatial_size=bounding_box.spatial_size, + canvas_size=bounding_boxes.canvas_size, **{param: value}, check_scripted_vs_eager=not (param == "shear" and isinstance(value, (int, float))), ) - @pytest.mark.parametrize("mask_type", ["segmentation", "detection"]) - def test_kernel_mask(self, mask_type): - self._check_kernel(F.affine_mask, make_input(datapoints.Mask, mask_type=mask_type)) + @pytest.mark.parametrize("make_mask", 
[make_segmentation_mask, make_detection_mask]) + def test_kernel_mask(self, make_mask): + self._check_kernel(F.affine_mask, make_mask()) def test_kernel_video(self): - self._check_kernel(F.affine_video, make_input(datapoints.Video)) + self._check_kernel(F.affine_video, make_video()) @pytest.mark.parametrize( - ("input_type", "kernel"), + ("kernel", "make_input"), [ - (torch.Tensor, F.affine_image_tensor), - (PIL.Image.Image, F.affine_image_pil), - (datapoints.Image, F.affine_image_tensor), - (datapoints.BoundingBox, F.affine_bounding_box), - (datapoints.Mask, F.affine_mask), - (datapoints.Video, F.affine_video), + (F.affine_image_tensor, make_image_tensor), + (F.affine_image_pil, make_image_pil), + (F.affine_image_tensor, make_image), + (F.affine_bounding_boxes, make_bounding_box), + (F.affine_mask, make_segmentation_mask), + (F.affine_video, make_video), ], ) - def test_dispatcher(self, kernel, input_type): - check_dispatcher(F.affine, kernel, make_input(input_type), **self._MINIMAL_AFFINE_KWARGS) + def test_dispatcher(self, kernel, make_input): + check_dispatcher(F.affine, kernel, make_input(), **self._MINIMAL_AFFINE_KWARGS) @pytest.mark.parametrize( - ("input_type", "kernel"), + ("kernel", "input_type"), [ - (torch.Tensor, F.affine_image_tensor), - (PIL.Image.Image, F.affine_image_pil), - (datapoints.Image, F.affine_image_tensor), - (datapoints.BoundingBox, F.affine_bounding_box), - (datapoints.Mask, F.affine_mask), - (datapoints.Video, F.affine_video), + (F.affine_image_tensor, torch.Tensor), + (F.affine_image_pil, PIL.Image.Image), + (F.affine_image_tensor, datapoints.Image), + (F.affine_bounding_boxes, datapoints.BoundingBoxes), + (F.affine_mask, datapoints.Mask), + (F.affine_video, datapoints.Video), ], ) def test_dispatcher_signature(self, kernel, input_type): check_dispatcher_signatures_match(F.affine, kernel=kernel, input_type=input_type) @pytest.mark.parametrize( - "input_type", - [torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.BoundingBox, datapoints.Mask, datapoints.Video], + "make_input", + [make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video], ) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_transform(self, input_type, device): - input = make_input(input_type, device=device) + def test_transform(self, make_input, device): + input = make_input(device=device) check_transform(transforms.RandomAffine, input, **self._CORRECTNESS_TRANSFORM_AFFINE_RANGES) @@ -1062,7 +1055,7 @@ def test_transform(self, input_type, device): ) @pytest.mark.parametrize("fill", CORRECTNESS_FILLS) def test_functional_image_correctness(self, angle, translate, scale, shear, center, interpolation, fill): - image = make_input(torch.Tensor, dtype=torch.uint8, device="cpu") + image = make_image(dtype=torch.uint8, device="cpu") fill = adapt_fill(fill, dtype=torch.uint8) @@ -1099,7 +1092,7 @@ def test_functional_image_correctness(self, angle, translate, scale, shear, cent @pytest.mark.parametrize("fill", CORRECTNESS_FILLS) @pytest.mark.parametrize("seed", list(range(5))) def test_transform_image_correctness(self, center, interpolation, fill, seed): - image = make_input(torch.Tensor, dtype=torch.uint8, device="cpu") + image = make_image(dtype=torch.uint8, device="cpu") fill = adapt_fill(fill, dtype=torch.uint8) @@ -1138,19 +1131,19 @@ def _compute_affine_matrix(self, *, angle, translate, scale, shear, center): true_matrix = np.matmul(t_matrix, np.matmul(c_matrix, np.matmul(rss_matrix, c_matrix_inv))) return true_matrix - def 
_reference_affine_bounding_box(self, bounding_box, *, angle, translate, scale, shear, center): + def _reference_affine_bounding_boxes(self, bounding_boxes, *, angle, translate, scale, shear, center): if center is None: - center = [s * 0.5 for s in bounding_box.spatial_size[::-1]] + center = [s * 0.5 for s in bounding_boxes.canvas_size[::-1]] affine_matrix = self._compute_affine_matrix( angle=angle, translate=translate, scale=scale, shear=shear, center=center ) affine_matrix = affine_matrix[:2, :] - expected_bboxes = reference_affine_bounding_box_helper( - bounding_box, - format=bounding_box.format, - spatial_size=bounding_box.spatial_size, + expected_bboxes = reference_affine_bounding_boxes_helper( + bounding_boxes, + format=bounding_boxes.format, + canvas_size=bounding_boxes.canvas_size, affine_matrix=affine_matrix, ) @@ -1162,19 +1155,19 @@ def _reference_affine_bounding_box(self, bounding_box, *, angle, translate, scal @pytest.mark.parametrize("scale", _CORRECTNESS_AFFINE_KWARGS["scale"]) @pytest.mark.parametrize("shear", _CORRECTNESS_AFFINE_KWARGS["shear"]) @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"]) - def test_functional_bounding_box_correctness(self, format, angle, translate, scale, shear, center): - bounding_box = make_input(datapoints.BoundingBox, format=format) + def test_functional_bounding_boxes_correctness(self, format, angle, translate, scale, shear, center): + bounding_boxes = make_bounding_box(format=format) actual = F.affine( - bounding_box, + bounding_boxes, angle=angle, translate=translate, scale=scale, shear=shear, center=center, ) - expected = self._reference_affine_bounding_box( - bounding_box, + expected = self._reference_affine_bounding_boxes( + bounding_boxes, angle=angle, translate=translate, scale=scale, @@ -1187,18 +1180,18 @@ def test_functional_bounding_box_correctness(self, format, angle, translate, sca @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat)) @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"]) @pytest.mark.parametrize("seed", list(range(5))) - def test_transform_bounding_box_correctness(self, format, center, seed): - bounding_box = make_input(datapoints.BoundingBox, format=format) + def test_transform_bounding_boxes_correctness(self, format, center, seed): + bounding_boxes = make_bounding_box(format=format) transform = transforms.RandomAffine(**self._CORRECTNESS_TRANSFORM_AFFINE_RANGES, center=center) torch.manual_seed(seed) - params = transform._get_params([bounding_box]) + params = transform._get_params([bounding_boxes]) torch.manual_seed(seed) - actual = transform(bounding_box) + actual = transform(bounding_boxes) - expected = self._reference_affine_bounding_box(bounding_box, **params, center=center) + expected = self._reference_affine_bounding_boxes(bounding_boxes, **params, center=center) torch.testing.assert_close(actual, expected) @@ -1208,8 +1201,8 @@ def test_transform_bounding_box_correctness(self, format, center, seed): @pytest.mark.parametrize("shear", _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES["shear"]) @pytest.mark.parametrize("seed", list(range(10))) def test_transform_get_params_bounds(self, degrees, translate, scale, shear, seed): - image = make_input(torch.Tensor) - height, width = F.get_spatial_size(image) + image = make_image() + height, width = F.get_size(image) transform = transforms.RandomAffine(degrees=degrees, translate=translate, scale=scale, shear=shear) @@ -1289,109 +1282,107 @@ class TestVerticalFlip: @pytest.mark.parametrize("dtype", [torch.float32, 
torch.uint8]) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_kernel_image_tensor(self, dtype, device): - check_kernel(F.vertical_flip_image_tensor, make_input(torch.Tensor, dtype=dtype, device=device)) + check_kernel(F.vertical_flip_image_tensor, make_image(dtype=dtype, device=device)) @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat)) @pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_kernel_bounding_box(self, format, dtype, device): - bounding_box = make_input(datapoints.BoundingBox, dtype=dtype, device=device, format=format) + def test_kernel_bounding_boxes(self, format, dtype, device): + bounding_boxes = make_bounding_box(format=format, dtype=dtype, device=device) check_kernel( - F.vertical_flip_bounding_box, - bounding_box, + F.vertical_flip_bounding_boxes, + bounding_boxes, format=format, - spatial_size=bounding_box.spatial_size, + canvas_size=bounding_boxes.canvas_size, ) - @pytest.mark.parametrize("mask_type", ["segmentation", "detection"]) - def test_kernel_mask(self, mask_type): - check_kernel(F.vertical_flip_mask, make_input(datapoints.Mask, mask_type=mask_type)) + @pytest.mark.parametrize("make_mask", [make_segmentation_mask, make_detection_mask]) + def test_kernel_mask(self, make_mask): + check_kernel(F.vertical_flip_mask, make_mask()) def test_kernel_video(self): - check_kernel(F.vertical_flip_video, make_input(datapoints.Video)) + check_kernel(F.vertical_flip_video, make_video()) @pytest.mark.parametrize( - ("input_type", "kernel"), + ("kernel", "make_input"), [ - (torch.Tensor, F.vertical_flip_image_tensor), - (PIL.Image.Image, F.vertical_flip_image_pil), - (datapoints.Image, F.vertical_flip_image_tensor), - (datapoints.BoundingBox, F.vertical_flip_bounding_box), - (datapoints.Mask, F.vertical_flip_mask), - (datapoints.Video, F.vertical_flip_video), + (F.vertical_flip_image_tensor, make_image_tensor), + (F.vertical_flip_image_pil, make_image_pil), + (F.vertical_flip_image_tensor, make_image), + (F.vertical_flip_bounding_boxes, make_bounding_box), + (F.vertical_flip_mask, make_segmentation_mask), + (F.vertical_flip_video, make_video), ], ) - def test_dispatcher(self, kernel, input_type): - check_dispatcher(F.vertical_flip, kernel, make_input(input_type)) + def test_dispatcher(self, kernel, make_input): + check_dispatcher(F.vertical_flip, kernel, make_input()) @pytest.mark.parametrize( - ("input_type", "kernel"), + ("kernel", "input_type"), [ - (torch.Tensor, F.vertical_flip_image_tensor), - (PIL.Image.Image, F.vertical_flip_image_pil), - (datapoints.Image, F.vertical_flip_image_tensor), - (datapoints.BoundingBox, F.vertical_flip_bounding_box), - (datapoints.Mask, F.vertical_flip_mask), - (datapoints.Video, F.vertical_flip_video), + (F.vertical_flip_image_tensor, torch.Tensor), + (F.vertical_flip_image_pil, PIL.Image.Image), + (F.vertical_flip_image_tensor, datapoints.Image), + (F.vertical_flip_bounding_boxes, datapoints.BoundingBoxes), + (F.vertical_flip_mask, datapoints.Mask), + (F.vertical_flip_video, datapoints.Video), ], ) def test_dispatcher_signature(self, kernel, input_type): check_dispatcher_signatures_match(F.vertical_flip, kernel=kernel, input_type=input_type) @pytest.mark.parametrize( - "input_type", - [torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.BoundingBox, datapoints.Mask, datapoints.Video], + "make_input", + [make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video], ) 
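# Same bookkeeping for the vertical case covered by this class (illustrative numbers only):
# on a canvas of height 10, a vertical flip maps y to height - y, so an XYXY box (2, 1, 5, 3)
# becomes (2, 7, 5, 9); this corresponds to the affine matrix [[1, 0, 0], [0, -1, height]]
# used by the reference helper further below.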
@pytest.mark.parametrize("device", cpu_and_cuda()) - def test_transform(self, input_type, device): - input = make_input(input_type, device=device) - - check_transform(transforms.RandomVerticalFlip, input, p=1) + def test_transform(self, make_input, device): + check_transform(transforms.RandomVerticalFlip, make_input(device=device), p=1) @pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)]) def test_image_correctness(self, fn): - image = make_input(torch.Tensor, dtype=torch.uint8, device="cpu") + image = make_image(dtype=torch.uint8, device="cpu") actual = fn(image) expected = F.to_image_tensor(F.vertical_flip(F.to_image_pil(image))) torch.testing.assert_close(actual, expected) - def _reference_vertical_flip_bounding_box(self, bounding_box): + def _reference_vertical_flip_bounding_boxes(self, bounding_boxes): affine_matrix = np.array( [ [1, 0, 0], - [0, -1, bounding_box.spatial_size[0]], + [0, -1, bounding_boxes.canvas_size[0]], ], - dtype="float64" if bounding_box.dtype == torch.float64 else "float32", + dtype="float64" if bounding_boxes.dtype == torch.float64 else "float32", ) - expected_bboxes = reference_affine_bounding_box_helper( - bounding_box, - format=bounding_box.format, - spatial_size=bounding_box.spatial_size, + expected_bboxes = reference_affine_bounding_boxes_helper( + bounding_boxes, + format=bounding_boxes.format, + canvas_size=bounding_boxes.canvas_size, affine_matrix=affine_matrix, ) - return datapoints.BoundingBox.wrap_like(bounding_box, expected_bboxes) + return datapoints.BoundingBoxes.wrap_like(bounding_boxes, expected_bboxes) @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat)) @pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)]) - def test_bounding_box_correctness(self, format, fn): - bounding_box = make_input(datapoints.BoundingBox, format=format) + def test_bounding_boxes_correctness(self, format, fn): + bounding_boxes = make_bounding_box(format=format) - actual = fn(bounding_box) - expected = self._reference_vertical_flip_bounding_box(bounding_box) + actual = fn(bounding_boxes) + expected = self._reference_vertical_flip_bounding_boxes(bounding_boxes) torch.testing.assert_close(actual, expected) @pytest.mark.parametrize( - "input_type", - [torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.BoundingBox, datapoints.Mask, datapoints.Video], + "make_input", + [make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video], ) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_transform_noop(self, input_type, device): - input = make_input(input_type, device=device) + def test_transform_noop(self, make_input, device): + input = make_input(device=device) transform = transforms.RandomVerticalFlip(p=0) @@ -1434,7 +1425,7 @@ def test_kernel_image_tensor(self, param, value, dtype, device): kwargs["angle"] = self._MINIMAL_AFFINE_KWARGS["angle"] check_kernel( F.rotate_image_tensor, - make_input(torch.Tensor, dtype=dtype, device=device), + make_image(dtype=dtype, device=device), **kwargs, check_scripted_vs_eager=not (param == "fill" and isinstance(value, (int, float))), ) @@ -1447,65 +1438,65 @@ def test_kernel_image_tensor(self, param, value, dtype, device): @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat)) @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8]) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_kernel_bounding_box(self, 
param, value, format, dtype, device): + def test_kernel_bounding_boxes(self, param, value, format, dtype, device): kwargs = {param: value} if param != "angle": kwargs["angle"] = self._MINIMAL_AFFINE_KWARGS["angle"] - bounding_box = make_input(datapoints.BoundingBox, dtype=dtype, device=device, format=format) + bounding_boxes = make_bounding_box(format=format, dtype=dtype, device=device) check_kernel( - F.rotate_bounding_box, - bounding_box, + F.rotate_bounding_boxes, + bounding_boxes, format=format, - spatial_size=bounding_box.spatial_size, + canvas_size=bounding_boxes.canvas_size, **kwargs, ) - @pytest.mark.parametrize("mask_type", ["segmentation", "detection"]) - def test_kernel_mask(self, mask_type): - check_kernel(F.rotate_mask, make_input(datapoints.Mask, mask_type=mask_type), **self._MINIMAL_AFFINE_KWARGS) + @pytest.mark.parametrize("make_mask", [make_segmentation_mask, make_detection_mask]) + def test_kernel_mask(self, make_mask): + check_kernel(F.rotate_mask, make_mask(), **self._MINIMAL_AFFINE_KWARGS) def test_kernel_video(self): - check_kernel(F.rotate_video, make_input(datapoints.Video), **self._MINIMAL_AFFINE_KWARGS) + check_kernel(F.rotate_video, make_video(), **self._MINIMAL_AFFINE_KWARGS) @pytest.mark.parametrize( - ("input_type", "kernel"), + ("kernel", "make_input"), [ - (torch.Tensor, F.rotate_image_tensor), - (PIL.Image.Image, F.rotate_image_pil), - (datapoints.Image, F.rotate_image_tensor), - (datapoints.BoundingBox, F.rotate_bounding_box), - (datapoints.Mask, F.rotate_mask), - (datapoints.Video, F.rotate_video), + (F.rotate_image_tensor, make_image_tensor), + (F.rotate_image_pil, make_image_pil), + (F.rotate_image_tensor, make_image), + (F.rotate_bounding_boxes, make_bounding_box), + (F.rotate_mask, make_segmentation_mask), + (F.rotate_video, make_video), ], ) - def test_dispatcher(self, kernel, input_type): - check_dispatcher(F.rotate, kernel, make_input(input_type), **self._MINIMAL_AFFINE_KWARGS) + def test_dispatcher(self, kernel, make_input): + check_dispatcher(F.rotate, kernel, make_input(), **self._MINIMAL_AFFINE_KWARGS) @pytest.mark.parametrize( - ("input_type", "kernel"), + ("kernel", "input_type"), [ - (torch.Tensor, F.rotate_image_tensor), - (PIL.Image.Image, F.rotate_image_pil), - (datapoints.Image, F.rotate_image_tensor), - (datapoints.BoundingBox, F.rotate_bounding_box), - (datapoints.Mask, F.rotate_mask), - (datapoints.Video, F.rotate_video), + (F.rotate_image_tensor, torch.Tensor), + (F.rotate_image_pil, PIL.Image.Image), + (F.rotate_image_tensor, datapoints.Image), + (F.rotate_bounding_boxes, datapoints.BoundingBoxes), + (F.rotate_mask, datapoints.Mask), + (F.rotate_video, datapoints.Video), ], ) def test_dispatcher_signature(self, kernel, input_type): check_dispatcher_signatures_match(F.rotate, kernel=kernel, input_type=input_type) @pytest.mark.parametrize( - "input_type", - [torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.BoundingBox, datapoints.Mask, datapoints.Video], + "make_input", + [make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video], ) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_transform(self, input_type, device): - input = make_input(input_type, device=device) - - check_transform(transforms.RandomRotation, input, **self._CORRECTNESS_TRANSFORM_AFFINE_RANGES) + def test_transform(self, make_input, device): + check_transform( + transforms.RandomRotation, make_input(device=device), **self._CORRECTNESS_TRANSFORM_AFFINE_RANGES + ) @pytest.mark.parametrize("angle", 
_CORRECTNESS_AFFINE_KWARGS["angle"]) @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"]) @@ -1515,7 +1506,7 @@ def test_transform(self, input_type, device): @pytest.mark.parametrize("expand", [False, True]) @pytest.mark.parametrize("fill", CORRECTNESS_FILLS) def test_functional_image_correctness(self, angle, center, interpolation, expand, fill): - image = make_input(torch.Tensor, dtype=torch.uint8, device="cpu") + image = make_image(dtype=torch.uint8, device="cpu") fill = adapt_fill(fill, dtype=torch.uint8) @@ -1537,7 +1528,7 @@ def test_functional_image_correctness(self, angle, center, interpolation, expand @pytest.mark.parametrize("fill", CORRECTNESS_FILLS) @pytest.mark.parametrize("seed", list(range(5))) def test_transform_image_correctness(self, center, interpolation, expand, fill, seed): - image = make_input(torch.Tensor, dtype=torch.uint8, device="cpu") + image = make_image(dtype=torch.uint8, device="cpu") fill = adapt_fill(fill, dtype=torch.uint8) @@ -1558,13 +1549,13 @@ def test_transform_image_correctness(self, center, interpolation, expand, fill, mae = (actual.float() - expected.float()).abs().mean() assert mae < 1 if interpolation is transforms.InterpolationMode.NEAREST else 6 - def _reference_rotate_bounding_box(self, bounding_box, *, angle, expand, center): + def _reference_rotate_bounding_boxes(self, bounding_boxes, *, angle, expand, center): # FIXME if expand: raise ValueError("This reference currently does not support expand=True") if center is None: - center = [s * 0.5 for s in bounding_box.spatial_size[::-1]] + center = [s * 0.5 for s in bounding_boxes.canvas_size[::-1]] a = np.cos(angle * np.pi / 180.0) b = np.sin(angle * np.pi / 180.0) @@ -1575,13 +1566,13 @@ def _reference_rotate_bounding_box(self, bounding_box, *, angle, expand, center) [a, b, cx - cx * a - b * cy], [-b, a, cy + cx * b - a * cy], ], - dtype="float64" if bounding_box.dtype == torch.float64 else "float32", + dtype="float64" if bounding_boxes.dtype == torch.float64 else "float32", ) - expected_bboxes = reference_affine_bounding_box_helper( - bounding_box, - format=bounding_box.format, - spatial_size=bounding_box.spatial_size, + expected_bboxes = reference_affine_bounding_boxes_helper( + bounding_boxes, + format=bounding_boxes.format, + canvas_size=bounding_boxes.canvas_size, affine_matrix=affine_matrix, ) @@ -1592,11 +1583,11 @@ def _reference_rotate_bounding_box(self, bounding_box, *, angle, expand, center) # TODO: add support for expand=True in the reference @pytest.mark.parametrize("expand", [False]) @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"]) - def test_functional_bounding_box_correctness(self, format, angle, expand, center): - bounding_box = make_input(datapoints.BoundingBox, format=format) + def test_functional_bounding_boxes_correctness(self, format, angle, expand, center): + bounding_boxes = make_bounding_box(format=format) - actual = F.rotate(bounding_box, angle=angle, expand=expand, center=center) - expected = self._reference_rotate_bounding_box(bounding_box, angle=angle, expand=expand, center=center) + actual = F.rotate(bounding_boxes, angle=angle, expand=expand, center=center) + expected = self._reference_rotate_bounding_boxes(bounding_boxes, angle=angle, expand=expand, center=center) torch.testing.assert_close(actual, expected) @@ -1605,18 +1596,18 @@ def test_functional_bounding_box_correctness(self, format, angle, expand, center @pytest.mark.parametrize("expand", [False]) @pytest.mark.parametrize("center", 
_CORRECTNESS_AFFINE_KWARGS["center"]) @pytest.mark.parametrize("seed", list(range(5))) - def test_transform_bounding_box_correctness(self, format, expand, center, seed): - bounding_box = make_input(datapoints.BoundingBox, format=format) + def test_transform_bounding_boxes_correctness(self, format, expand, center, seed): + bounding_boxes = make_bounding_box(format=format) transform = transforms.RandomRotation(**self._CORRECTNESS_TRANSFORM_AFFINE_RANGES, expand=expand, center=center) torch.manual_seed(seed) - params = transform._get_params([bounding_box]) + params = transform._get_params([bounding_boxes]) torch.manual_seed(seed) - actual = transform(bounding_box) + actual = transform(bounding_boxes) - expected = self._reference_rotate_bounding_box(bounding_box, **params, expand=expand, center=center) + expected = self._reference_rotate_bounding_boxes(bounding_boxes, **params, expand=expand, center=center) torch.testing.assert_close(actual, expected) @@ -1655,3 +1646,393 @@ def test_transform_negative_degrees_error(self): def test_transform_unknown_fill_error(self): with pytest.raises(TypeError, match="Got inappropriate fill arg"): transforms.RandomAffine(degrees=0, fill="fill") + + +class TestCompose: + class BuiltinTransform(transforms.Transform): + def _transform(self, inpt, params): + return inpt + + class PackedInputTransform(nn.Module): + def forward(self, sample): + assert len(sample) == 2 + return sample + + class UnpackedInputTransform(nn.Module): + def forward(self, image, label): + return image, label + + @pytest.mark.parametrize( + "transform_clss", + [ + [BuiltinTransform], + [PackedInputTransform], + [UnpackedInputTransform], + [BuiltinTransform, BuiltinTransform], + [PackedInputTransform, PackedInputTransform], + [UnpackedInputTransform, UnpackedInputTransform], + [BuiltinTransform, PackedInputTransform, BuiltinTransform], + [BuiltinTransform, UnpackedInputTransform, BuiltinTransform], + [PackedInputTransform, BuiltinTransform, PackedInputTransform], + [UnpackedInputTransform, BuiltinTransform, UnpackedInputTransform], + ], + ) + @pytest.mark.parametrize("unpack", [True, False]) + def test_packed_unpacked(self, transform_clss, unpack): + needs_packed_inputs = any(issubclass(cls, self.PackedInputTransform) for cls in transform_clss) + needs_unpacked_inputs = any(issubclass(cls, self.UnpackedInputTransform) for cls in transform_clss) + assert not (needs_packed_inputs and needs_unpacked_inputs) + + transform = transforms.Compose([cls() for cls in transform_clss]) + + image = make_image() + label = 3 + packed_input = (image, label) + + def call_transform(): + if unpack: + return transform(*packed_input) + else: + return transform(packed_input) + + if needs_unpacked_inputs and not unpack: + with pytest.raises(TypeError, match="missing 1 required positional argument"): + call_transform() + elif needs_packed_inputs and unpack: + with pytest.raises(TypeError, match="takes 2 positional arguments but 3 were given"): + call_transform() + else: + output = call_transform() + + assert isinstance(output, tuple) and len(output) == 2 + assert output[0] is image + assert output[1] is label + + +class TestToDtype: + @pytest.mark.parametrize( + ("kernel", "make_input"), + [ + (F.to_dtype_image_tensor, make_image_tensor), + (F.to_dtype_image_tensor, make_image), + (F.to_dtype_video, make_video), + ], + ) + @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8]) + @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8]) + 
@pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("scale", (True, False)) + def test_kernel(self, kernel, make_input, input_dtype, output_dtype, device, scale): + check_kernel( + kernel, + make_input(dtype=input_dtype, device=device), + expect_same_dtype=input_dtype is output_dtype, + dtype=output_dtype, + scale=scale, + ) + + @pytest.mark.parametrize( + ("kernel", "make_input"), + [ + (F.to_dtype_image_tensor, make_image_tensor), + (F.to_dtype_image_tensor, make_image), + (F.to_dtype_video, make_video), + ], + ) + @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8]) + @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("scale", (True, False)) + def test_dispatcher(self, kernel, make_input, input_dtype, output_dtype, device, scale): + check_dispatcher( + F.to_dtype, + kernel, + make_input(dtype=input_dtype, device=device), + # TODO: we could leave check_dispatch to True but it currently fails + # in _check_dispatcher_dispatch because there is no to_dtype() method on the datapoints. + # We should be able to put this back if we change the dispatch + # mechanism e.g. via https://github.com/pytorch/vision/pull/7733 + check_dispatch=False, + dtype=output_dtype, + scale=scale, + ) + + @pytest.mark.parametrize( + "make_input", + [make_image_tensor, make_image, make_bounding_box, make_segmentation_mask, make_video], + ) + @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8]) + @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("scale", (True, False)) + @pytest.mark.parametrize("as_dict", (True, False)) + def test_transform(self, make_input, input_dtype, output_dtype, device, scale, as_dict): + input = make_input(dtype=input_dtype, device=device) + if as_dict: + output_dtype = {type(input): output_dtype} + check_transform(transforms.ToDtype, input, dtype=output_dtype, scale=scale) + + def reference_convert_dtype_image_tensor(self, image, dtype=torch.float, scale=False): + input_dtype = image.dtype + output_dtype = dtype + + if not scale: + return image.to(dtype) + + if output_dtype == input_dtype: + return image + + def fn(value): + if input_dtype.is_floating_point: + if output_dtype.is_floating_point: + return value + else: + return round(decimal.Decimal(value) * torch.iinfo(output_dtype).max) + else: + input_max_value = torch.iinfo(input_dtype).max + + if output_dtype.is_floating_point: + return float(decimal.Decimal(value) / input_max_value) + else: + output_max_value = torch.iinfo(output_dtype).max + + if input_max_value > output_max_value: + factor = (input_max_value + 1) // (output_max_value + 1) + return value / factor + else: + factor = (output_max_value + 1) // (input_max_value + 1) + return value * factor + + return torch.tensor(tree_map(fn, image.tolist()), dtype=dtype, device=image.device) + + @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8]) + @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("scale", (True, False)) + def test_image_correctness(self, input_dtype, output_dtype, device, scale): + if input_dtype.is_floating_point and output_dtype == torch.int64: + pytest.xfail("float to int64 conversion is not supported") 
+ + input = make_image(dtype=input_dtype, device=device) + + out = F.to_dtype(input, dtype=output_dtype, scale=scale) + expected = self.reference_convert_dtype_image_tensor(input, dtype=output_dtype, scale=scale) + + if input_dtype.is_floating_point and not output_dtype.is_floating_point and scale: + torch.testing.assert_close(out, expected, atol=1, rtol=0) + else: + torch.testing.assert_close(out, expected) + + def was_scaled(self, inpt): + # this assumes the target dtype is float + return inpt.max() <= 1 + + def make_inpt_with_bbox_and_mask(self, make_input): + H, W = 10, 10 + inpt_dtype = torch.uint8 + bbox_dtype = torch.float32 + mask_dtype = torch.bool + sample = { + "inpt": make_input(size=(H, W), dtype=inpt_dtype), + "bbox": make_bounding_box(canvas_size=(H, W), dtype=bbox_dtype), + "mask": make_detection_mask(size=(H, W), dtype=mask_dtype), + } + + return sample, inpt_dtype, bbox_dtype, mask_dtype + + @pytest.mark.parametrize("make_input", (make_image_tensor, make_image, make_video)) + @pytest.mark.parametrize("scale", (True, False)) + def test_dtype_not_a_dict(self, make_input, scale): + # assert only inpt gets transformed when dtype isn't a dict + + sample, inpt_dtype, bbox_dtype, mask_dtype = self.make_inpt_with_bbox_and_mask(make_input) + out = transforms.ToDtype(dtype=torch.float32, scale=scale)(sample) + + assert out["inpt"].dtype != inpt_dtype + assert out["inpt"].dtype == torch.float32 + if scale: + assert self.was_scaled(out["inpt"]) + else: + assert not self.was_scaled(out["inpt"]) + assert out["bbox"].dtype == bbox_dtype + assert out["mask"].dtype == mask_dtype + + @pytest.mark.parametrize("make_input", (make_image_tensor, make_image, make_video)) + def test_others_catch_all_and_none(self, make_input): + # make sure "others" works as a catch-all and that None means no conversion + + sample, inpt_dtype, bbox_dtype, mask_dtype = self.make_inpt_with_bbox_and_mask(make_input) + out = transforms.ToDtype(dtype={datapoints.Mask: torch.int64, "others": None})(sample) + assert out["inpt"].dtype == inpt_dtype + assert out["bbox"].dtype == bbox_dtype + assert out["mask"].dtype != mask_dtype + assert out["mask"].dtype == torch.int64 + + @pytest.mark.parametrize("make_input", (make_image_tensor, make_image, make_video)) + def test_typical_use_case(self, make_input): + # Typical use-case: want to convert dtype and scale for inpt and just dtype for masks. 
+ # This just makes sure we now have a decent API for this + + sample, inpt_dtype, bbox_dtype, mask_dtype = self.make_inpt_with_bbox_and_mask(make_input) + out = transforms.ToDtype( + dtype={type(sample["inpt"]): torch.float32, datapoints.Mask: torch.int64, "others": None}, scale=True + )(sample) + assert out["inpt"].dtype != inpt_dtype + assert out["inpt"].dtype == torch.float32 + assert self.was_scaled(out["inpt"]) + assert out["bbox"].dtype == bbox_dtype + assert out["mask"].dtype != mask_dtype + assert out["mask"].dtype == torch.int64 + + @pytest.mark.parametrize("make_input", (make_image_tensor, make_image, make_video)) + def test_errors_warnings(self, make_input): + sample, inpt_dtype, bbox_dtype, mask_dtype = self.make_inpt_with_bbox_and_mask(make_input) + + with pytest.raises(ValueError, match="No dtype was specified for"): + out = transforms.ToDtype(dtype={datapoints.Mask: torch.float32})(sample) + with pytest.warns(UserWarning, match=re.escape("plain `torch.Tensor` will *not* be transformed")): + transforms.ToDtype(dtype={torch.Tensor: torch.float32, datapoints.Image: torch.float32}) + with pytest.warns(UserWarning, match="no scaling will be done"): + out = transforms.ToDtype(dtype={"others": None}, scale=True)(sample) + assert out["inpt"].dtype == inpt_dtype + assert out["bbox"].dtype == bbox_dtype + assert out["mask"].dtype == mask_dtype + + +class TestCutMixMixUp: + class DummyDataset: + def __init__(self, size, num_classes): + self.size = size + self.num_classes = num_classes + assert size < num_classes + + def __getitem__(self, idx): + img = torch.rand(3, 100, 100) + label = idx # This ensures all labels in a batch are unique and makes testing easier + return img, label + + def __len__(self): + return self.size + + @pytest.mark.parametrize("T", [transforms.CutMix, transforms.MixUp]) + def test_supported_input_structure(self, T): + + batch_size = 32 + num_classes = 100 + + dataset = self.DummyDataset(size=batch_size, num_classes=num_classes) + + cutmix_mixup = T(num_classes=num_classes) + + dl = DataLoader(dataset, batch_size=batch_size) + + # Input sanity checks + img, target = next(iter(dl)) + input_img_size = img.shape[-3:] + assert isinstance(img, torch.Tensor) and isinstance(target, torch.Tensor) + assert target.shape == (batch_size,) + + def check_output(img, target): + assert img.shape == (batch_size, *input_img_size) + assert target.shape == (batch_size, num_classes) + torch.testing.assert_close(target.sum(axis=-1), torch.ones(batch_size)) + num_non_zero_labels = (target != 0).sum(axis=-1) + assert (num_non_zero_labels == 2).all() + + # After Dataloader, as unpacked input + img, target = next(iter(dl)) + assert target.shape == (batch_size,) + img, target = cutmix_mixup(img, target) + check_output(img, target) + + # After Dataloader, as packed input + packed_from_dl = next(iter(dl)) + assert isinstance(packed_from_dl, list) + img, target = cutmix_mixup(packed_from_dl) + check_output(img, target) + + # As collation function. We expect default_collate to be used by users. 
+ def collate_fn_1(batch): + return cutmix_mixup(default_collate(batch)) + + def collate_fn_2(batch): + return cutmix_mixup(*default_collate(batch)) + + for collate_fn in (collate_fn_1, collate_fn_2): + dl = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn) + img, target = next(iter(dl)) + check_output(img, target) + + @needs_cuda + @pytest.mark.parametrize("T", [transforms.CutMix, transforms.MixUp]) + def test_cpu_vs_gpu(self, T): + num_classes = 10 + batch_size = 3 + H, W = 12, 12 + + imgs = torch.rand(batch_size, 3, H, W) + labels = torch.randint(0, num_classes, (batch_size,)) + cutmix_mixup = T(alpha=0.5, num_classes=num_classes) + + _check_kernel_cuda_vs_cpu(cutmix_mixup, imgs, labels, rtol=None, atol=None) + + @pytest.mark.parametrize("T", [transforms.CutMix, transforms.MixUp]) + def test_error(self, T): + + num_classes = 10 + batch_size = 9 + + imgs = torch.rand(batch_size, 3, 12, 12) + cutmix_mixup = T(alpha=0.5, num_classes=num_classes) + + for input_with_bad_type in ( + F.to_pil_image(imgs[0]), + datapoints.Mask(torch.rand(12, 12)), + datapoints.BoundingBoxes(torch.rand(2, 4), format="XYXY", canvas_size=12), + ): + with pytest.raises(ValueError, match="does not support PIL images, "): + cutmix_mixup(input_with_bad_type) + + with pytest.raises(ValueError, match="Could not infer where the labels are"): + cutmix_mixup({"img": imgs, "Nothing_else": 3}) + + with pytest.raises(ValueError, match="labels tensor should be of shape"): + # Note: the error message isn't ideal, but that's because the label heuristic found the img as the label + # It's OK, it's an edge-case. The important thing is that this fails loudly instead of passing silently + cutmix_mixup(imgs) + + with pytest.raises(ValueError, match="When using the default labels_getter"): + cutmix_mixup(imgs, "not_a_tensor") + + with pytest.raises(ValueError, match="labels tensor should be of shape"): + cutmix_mixup(imgs, torch.randint(0, 2, size=(2, 3))) + + with pytest.raises(ValueError, match="Expected a batched input with 4 dims"): + cutmix_mixup(imgs[None, None], torch.randint(0, num_classes, size=(batch_size,))) + + with pytest.raises(ValueError, match="does not match the batch size of the labels"): + cutmix_mixup(imgs, torch.randint(0, num_classes, size=(batch_size + 1,))) + + with pytest.raises(ValueError, match="labels tensor should be of shape"): + # The purpose of this check is more about documenting the current + # behaviour of what happens on a Compose(), rather than actually + # asserting the expected behaviour. We may support Compose() in the + # future, e.g. for 2 consecutive CutMix? 
+ labels = torch.randint(0, num_classes, size=(batch_size,)) + transforms.Compose([cutmix_mixup, cutmix_mixup])(imgs, labels) + + +@pytest.mark.parametrize("key", ("labels", "LABELS", "LaBeL", "SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT")) +@pytest.mark.parametrize("sample_type", (tuple, list, dict)) +def test_labels_getter_default_heuristic(key, sample_type): + labels = torch.arange(10) + sample = {key: labels, "another_key": "whatever"} + if sample_type is not dict: + sample = sample_type((None, sample, "whatever_again")) + assert transforms._utils._find_labels_default_heuristic(sample) is labels + + if key.lower() != "labels": + # If "labels" is in the dict (case-insensitive), + # it takes precedence over other keys which would otherwise be a match + d = {key: "something_else", "labels": labels} + assert transforms._utils._find_labels_default_heuristic(d) is labels diff --git a/test/test_transforms_v2_utils.py b/test/test_transforms_v2_utils.py index 198ab39a475..f880dac6c67 100644 --- a/test/test_transforms_v2_utils.py +++ b/test/test_transforms_v2_utils.py @@ -4,36 +4,36 @@ import torch import torchvision.transforms.v2.utils -from common_utils import make_bounding_box, make_detection_mask, make_image +from common_utils import DEFAULT_SIZE, make_bounding_box, make_detection_mask, make_image from torchvision import datapoints from torchvision.transforms.v2.functional import to_image_pil from torchvision.transforms.v2.utils import has_all, has_any -IMAGE = make_image(color_space="RGB") -BOUNDING_BOX = make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, spatial_size=IMAGE.spatial_size) -MASK = make_detection_mask(size=IMAGE.spatial_size) +IMAGE = make_image(DEFAULT_SIZE, color_space="RGB") +BOUNDING_BOX = make_bounding_box(DEFAULT_SIZE, format=datapoints.BoundingBoxFormat.XYXY) +MASK = make_detection_mask(DEFAULT_SIZE) @pytest.mark.parametrize( ("sample", "types", "expected"), [ ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image,), True), - ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBox,), True), + ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBoxes,), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Mask,), True), - ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBox), True), + ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBoxes), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), True), - ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBox, datapoints.Mask), True), - ((MASK,), (datapoints.Image, datapoints.BoundingBox), False), + ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBoxes, datapoints.Mask), True), + ((MASK,), (datapoints.Image, datapoints.BoundingBoxes), False), ((BOUNDING_BOX,), (datapoints.Image, datapoints.Mask), False), - ((IMAGE,), (datapoints.BoundingBox, datapoints.Mask), False), + ((IMAGE,), (datapoints.BoundingBoxes, datapoints.Mask), False), ( (IMAGE, BOUNDING_BOX, MASK), - (datapoints.Image, datapoints.BoundingBox, datapoints.Mask), + (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), True, ), - ((), (datapoints.Image, datapoints.BoundingBox, datapoints.Mask), False), + ((), (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), False), ((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, datapoints.Image),), True), ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False), ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True), @@ -58,30 +58,30 @@ def test_has_any(sample, types, expected): ("sample", "types", "expected"), [ ((IMAGE, BOUNDING_BOX, MASK), 
(datapoints.Image,), True), - ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBox,), True), + ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBoxes,), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Mask,), True), - ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBox), True), + ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBoxes), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), True), - ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBox, datapoints.Mask), True), + ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBoxes, datapoints.Mask), True), ( (IMAGE, BOUNDING_BOX, MASK), - (datapoints.Image, datapoints.BoundingBox, datapoints.Mask), + (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), True, ), - ((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBox), False), + ((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBoxes), False), ((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), False), - ((IMAGE, MASK), (datapoints.BoundingBox, datapoints.Mask), False), + ((IMAGE, MASK), (datapoints.BoundingBoxes, datapoints.Mask), False), ( (IMAGE, BOUNDING_BOX, MASK), - (datapoints.Image, datapoints.BoundingBox, datapoints.Mask), + (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), True, ), - ((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBox, datapoints.Mask), False), - ((IMAGE, MASK), (datapoints.Image, datapoints.BoundingBox, datapoints.Mask), False), - ((IMAGE, BOUNDING_BOX), (datapoints.Image, datapoints.BoundingBox, datapoints.Mask), False), + ((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), False), + ((IMAGE, MASK), (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), False), + ((IMAGE, BOUNDING_BOX), (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), False), ( (IMAGE, BOUNDING_BOX, MASK), - (lambda obj: isinstance(obj, (datapoints.Image, datapoints.BoundingBox, datapoints.Mask)),), + (lambda obj: isinstance(obj, (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask)),), True, ), ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False), diff --git a/test/transforms_v2_dispatcher_infos.py b/test/transforms_v2_dispatcher_infos.py index 6f61526f382..239954dda68 100644 --- a/test/transforms_v2_dispatcher_infos.py +++ b/test/transforms_v2_dispatcher_infos.py @@ -143,7 +143,7 @@ def fill_sequence_needs_broadcast(args_kwargs): kernels={ datapoints.Image: F.crop_image_tensor, datapoints.Video: F.crop_video, - datapoints.BoundingBox: F.crop_bounding_box, + datapoints.BoundingBoxes: F.crop_bounding_boxes, datapoints.Mask: F.crop_mask, }, pil_kernel_info=PILKernelInfo(F.crop_image_pil, kernel_name="crop_image_pil"), @@ -153,7 +153,7 @@ def fill_sequence_needs_broadcast(args_kwargs): kernels={ datapoints.Image: F.resized_crop_image_tensor, datapoints.Video: F.resized_crop_video, - datapoints.BoundingBox: F.resized_crop_bounding_box, + datapoints.BoundingBoxes: F.resized_crop_bounding_boxes, datapoints.Mask: F.resized_crop_mask, }, pil_kernel_info=PILKernelInfo(F.resized_crop_image_pil), @@ -163,7 +163,7 @@ def fill_sequence_needs_broadcast(args_kwargs): kernels={ datapoints.Image: F.pad_image_tensor, datapoints.Video: F.pad_video, - datapoints.BoundingBox: F.pad_bounding_box, + datapoints.BoundingBoxes: F.pad_bounding_boxes, datapoints.Mask: F.pad_mask, }, pil_kernel_info=PILKernelInfo(F.pad_image_pil, kernel_name="pad_image_pil"), @@ -185,7 +185,7 @@ def fill_sequence_needs_broadcast(args_kwargs): 
kernels={ datapoints.Image: F.perspective_image_tensor, datapoints.Video: F.perspective_video, - datapoints.BoundingBox: F.perspective_bounding_box, + datapoints.BoundingBoxes: F.perspective_bounding_boxes, datapoints.Mask: F.perspective_mask, }, pil_kernel_info=PILKernelInfo(F.perspective_image_pil), @@ -199,7 +199,7 @@ def fill_sequence_needs_broadcast(args_kwargs): kernels={ datapoints.Image: F.elastic_image_tensor, datapoints.Video: F.elastic_video, - datapoints.BoundingBox: F.elastic_bounding_box, + datapoints.BoundingBoxes: F.elastic_bounding_boxes, datapoints.Mask: F.elastic_mask, }, pil_kernel_info=PILKernelInfo(F.elastic_image_pil), @@ -210,7 +210,7 @@ def fill_sequence_needs_broadcast(args_kwargs): kernels={ datapoints.Image: F.center_crop_image_tensor, datapoints.Video: F.center_crop_video, - datapoints.BoundingBox: F.center_crop_bounding_box, + datapoints.BoundingBoxes: F.center_crop_bounding_boxes, datapoints.Mask: F.center_crop_mask, }, pil_kernel_info=PILKernelInfo(F.center_crop_image_pil), @@ -364,16 +364,6 @@ def fill_sequence_needs_broadcast(args_kwargs): xfail_jit_python_scalar_arg("std"), ], ), - DispatcherInfo( - F.convert_dtype, - kernels={ - datapoints.Image: F.convert_dtype_image_tensor, - datapoints.Video: F.convert_dtype_video, - }, - test_marks=[ - skip_dispatch_datapoint, - ], - ), DispatcherInfo( F.uniform_temporal_subsample, kernels={ @@ -384,15 +374,15 @@ def fill_sequence_needs_broadcast(args_kwargs): ], ), DispatcherInfo( - F.clamp_bounding_box, - kernels={datapoints.BoundingBox: F.clamp_bounding_box}, + F.clamp_bounding_boxes, + kernels={datapoints.BoundingBoxes: F.clamp_bounding_boxes}, test_marks=[ skip_dispatch_datapoint, ], ), DispatcherInfo( - F.convert_format_bounding_box, - kernels={datapoints.BoundingBox: F.convert_format_bounding_box}, + F.convert_format_bounding_boxes, + kernels={datapoints.BoundingBoxes: F.convert_format_bounding_boxes}, test_marks=[ skip_dispatch_datapoint, ], diff --git a/test/transforms_v2_kernel_infos.py b/test/transforms_v2_kernel_infos.py index cae8d3157e9..85eb24a806c 100644 --- a/test/transforms_v2_kernel_infos.py +++ b/test/transforms_v2_kernel_infos.py @@ -1,4 +1,3 @@ -import decimal import functools import itertools @@ -11,6 +10,7 @@ from common_utils import ( ArgsKwargs, combinations_grid, + DEFAULT_PORTRAIT_SPATIAL_SIZE, get_num_channels, ImageLoader, InfoBase, @@ -26,7 +26,6 @@ mark_framework_limitation, TestMark, ) -from torch.utils._pytree import tree_map from torchvision import datapoints from torchvision.transforms._functional_tensor import _max_value as get_max_value, _parse_pad_padding @@ -185,13 +184,13 @@ def float32_vs_uint8_fill_adapter(other_args, kwargs): return other_args, dict(kwargs, fill=fill) -def reference_affine_bounding_box_helper(bounding_box, *, format, spatial_size, affine_matrix): - def transform(bbox, affine_matrix_, format_, spatial_size_): +def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_size, affine_matrix): + def transform(bbox, affine_matrix_, format_, canvas_size_): # Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1 in_dtype = bbox.dtype if not torch.is_floating_point(bbox): bbox = bbox.float() - bbox_xyxy = F.convert_format_bounding_box( + bbox_xyxy = F.convert_format_bounding_boxes( bbox.as_subclass(torch.Tensor), old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY, @@ -215,18 +214,18 @@ def transform(bbox, affine_matrix_, format_, spatial_size_): ], dtype=bbox_xyxy.dtype, ) - out_bbox = 
F.convert_format_bounding_box( + out_bbox = F.convert_format_bounding_boxes( out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_, inplace=True ) # It is important to clamp before casting, especially for CXCYWH format, dtype=int64 - out_bbox = F.clamp_bounding_box(out_bbox, format=format_, spatial_size=spatial_size_) + out_bbox = F.clamp_bounding_boxes(out_bbox, format=format_, canvas_size=canvas_size_) out_bbox = out_bbox.to(dtype=in_dtype) return out_bbox - if bounding_box.ndim < 2: - bounding_box = [bounding_box] + if bounding_boxes.ndim < 2: + bounding_boxes = [bounding_boxes] - expected_bboxes = [transform(bbox, affine_matrix, format, spatial_size) for bbox in bounding_box] + expected_bboxes = [transform(bbox, affine_matrix, format, canvas_size) for bbox in bounding_boxes] if len(expected_bboxes) > 1: expected_bboxes = torch.stack(expected_bboxes) else: @@ -235,31 +234,34 @@ def transform(bbox, affine_matrix_, format_, spatial_size_): return expected_bboxes -def sample_inputs_convert_format_bounding_box(): +def sample_inputs_convert_format_bounding_boxes(): formats = list(datapoints.BoundingBoxFormat) - for bounding_box_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats): - yield ArgsKwargs(bounding_box_loader, old_format=bounding_box_loader.format, new_format=new_format) + for bounding_boxes_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats): + yield ArgsKwargs(bounding_boxes_loader, old_format=bounding_boxes_loader.format, new_format=new_format) -def reference_convert_format_bounding_box(bounding_box, old_format, new_format): +def reference_convert_format_bounding_boxes(bounding_boxes, old_format, new_format): return torchvision.ops.box_convert( - bounding_box, in_fmt=old_format.name.lower(), out_fmt=new_format.name.lower() - ).to(bounding_box.dtype) + bounding_boxes, in_fmt=old_format.name.lower(), out_fmt=new_format.name.lower() + ).to(bounding_boxes.dtype) -def reference_inputs_convert_format_bounding_box(): - for args_kwargs in sample_inputs_convert_format_bounding_box(): +def reference_inputs_convert_format_bounding_boxes(): + for args_kwargs in sample_inputs_convert_format_bounding_boxes(): if len(args_kwargs.args[0].shape) == 2: yield args_kwargs KERNEL_INFOS.append( KernelInfo( - F.convert_format_bounding_box, - sample_inputs_fn=sample_inputs_convert_format_bounding_box, - reference_fn=reference_convert_format_bounding_box, - reference_inputs_fn=reference_inputs_convert_format_bounding_box, + F.convert_format_bounding_boxes, + sample_inputs_fn=sample_inputs_convert_format_bounding_boxes, + reference_fn=reference_convert_format_bounding_boxes, + reference_inputs_fn=reference_inputs_convert_format_bounding_boxes, logs_usage=True, + closeness_kwargs={ + (("TestKernels", "test_against_reference"), torch.int64, "cpu"): dict(atol=1, rtol=0), + }, ), ) @@ -288,15 +290,15 @@ def reference_inputs_crop_image_tensor(): yield ArgsKwargs(image_loader, **params) -def sample_inputs_crop_bounding_box(): - for bounding_box_loader, params in itertools.product( +def sample_inputs_crop_bounding_boxes(): + for bounding_boxes_loader, params in itertools.product( make_bounding_box_loaders(), [_CROP_PARAMS[0], _CROP_PARAMS[-1]] ): - yield ArgsKwargs(bounding_box_loader, format=bounding_box_loader.format, **params) + yield ArgsKwargs(bounding_boxes_loader, format=bounding_boxes_loader.format, **params) def sample_inputs_crop_mask(): - for mask_loader in make_mask_loaders(sizes=[(16, 17)], 
num_categories=["random"], num_objects=["random"]): + for mask_loader in make_mask_loaders(sizes=[(16, 17)], num_categories=[10], num_objects=[5]): yield ArgsKwargs(mask_loader, top=4, left=3, height=7, width=8) @@ -306,31 +308,31 @@ def reference_inputs_crop_mask(): def sample_inputs_crop_video(): - for video_loader in make_video_loaders(sizes=[(16, 17)], num_frames=["random"]): + for video_loader in make_video_loaders(sizes=[(16, 17)], num_frames=[3]): yield ArgsKwargs(video_loader, top=4, left=3, height=7, width=8) -def reference_crop_bounding_box(bounding_box, *, format, top, left, height, width): +def reference_crop_bounding_boxes(bounding_boxes, *, format, top, left, height, width): affine_matrix = np.array( [ [1, 0, -left], [0, 1, -top], ], - dtype="float64" if bounding_box.dtype == torch.float64 else "float32", + dtype="float64" if bounding_boxes.dtype == torch.float64 else "float32", ) - spatial_size = (height, width) - expected_bboxes = reference_affine_bounding_box_helper( - bounding_box, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix + canvas_size = (height, width) + expected_bboxes = reference_affine_bounding_boxes_helper( + bounding_boxes, format=format, canvas_size=canvas_size, affine_matrix=affine_matrix ) - return expected_bboxes, spatial_size + return expected_bboxes, canvas_size -def reference_inputs_crop_bounding_box(): - for bounding_box_loader, params in itertools.product( +def reference_inputs_crop_bounding_boxes(): + for bounding_boxes_loader, params in itertools.product( make_bounding_box_loaders(extra_dims=((), (4,))), [_CROP_PARAMS[0], _CROP_PARAMS[-1]] ): - yield ArgsKwargs(bounding_box_loader, format=bounding_box_loader.format, **params) + yield ArgsKwargs(bounding_boxes_loader, format=bounding_boxes_loader.format, **params) KERNEL_INFOS.extend( @@ -344,10 +346,10 @@ def reference_inputs_crop_bounding_box(): float32_vs_uint8=True, ), KernelInfo( - F.crop_bounding_box, - sample_inputs_fn=sample_inputs_crop_bounding_box, - reference_fn=reference_crop_bounding_box, - reference_inputs_fn=reference_inputs_crop_bounding_box, + F.crop_bounding_boxes, + sample_inputs_fn=sample_inputs_crop_bounding_boxes, + reference_fn=reference_crop_bounding_boxes, + reference_inputs_fn=reference_inputs_crop_bounding_boxes, ), KernelInfo( F.crop_mask, @@ -404,9 +406,9 @@ def reference_inputs_resized_crop_image_tensor(): ) -def sample_inputs_resized_crop_bounding_box(): - for bounding_box_loader in make_bounding_box_loaders(): - yield ArgsKwargs(bounding_box_loader, format=bounding_box_loader.format, **_RESIZED_CROP_PARAMS[0]) +def sample_inputs_resized_crop_bounding_boxes(): + for bounding_boxes_loader in make_bounding_box_loaders(): + yield ArgsKwargs(bounding_boxes_loader, format=bounding_boxes_loader.format, **_RESIZED_CROP_PARAMS[0]) def sample_inputs_resized_crop_mask(): @@ -415,7 +417,7 @@ def sample_inputs_resized_crop_mask(): def sample_inputs_resized_crop_video(): - for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]): yield ArgsKwargs(video_loader, **_RESIZED_CROP_PARAMS[0]) @@ -434,8 +436,8 @@ def sample_inputs_resized_crop_video(): }, ), KernelInfo( - F.resized_crop_bounding_box, - sample_inputs_fn=sample_inputs_resized_crop_bounding_box, + F.resized_crop_bounding_boxes, + sample_inputs_fn=sample_inputs_resized_crop_bounding_boxes, ), KernelInfo( F.resized_crop_mask, @@ -457,7 +459,7 @@ def sample_inputs_resized_crop_video(): 
def sample_inputs_pad_image_tensor(): make_pad_image_loaders = functools.partial( - make_image_loaders, sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32] + make_image_loaders, sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=["RGB"], dtypes=[torch.float32] ) for image_loader, padding in itertools.product( @@ -498,21 +500,21 @@ def reference_inputs_pad_image_tensor(): yield ArgsKwargs(image_loader, fill=fill, **params) -def sample_inputs_pad_bounding_box(): - for bounding_box_loader, padding in itertools.product( +def sample_inputs_pad_bounding_boxes(): + for bounding_boxes_loader, padding in itertools.product( make_bounding_box_loaders(), [1, (1,), (1, 2), (1, 2, 3, 4), [1], [1, 2], [1, 2, 3, 4]] ): yield ArgsKwargs( - bounding_box_loader, - format=bounding_box_loader.format, - spatial_size=bounding_box_loader.spatial_size, + bounding_boxes_loader, + format=bounding_boxes_loader.format, + canvas_size=bounding_boxes_loader.canvas_size, padding=padding, padding_mode="constant", ) def sample_inputs_pad_mask(): - for mask_loader in make_mask_loaders(sizes=["random"], num_categories=["random"], num_objects=["random"]): + for mask_loader in make_mask_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_categories=[10], num_objects=[5]): yield ArgsKwargs(mask_loader, padding=[1]) @@ -524,11 +526,11 @@ def reference_inputs_pad_mask(): def sample_inputs_pad_video(): - for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]): yield ArgsKwargs(video_loader, padding=[1]) -def reference_pad_bounding_box(bounding_box, *, format, spatial_size, padding, padding_mode): +def reference_pad_bounding_boxes(bounding_boxes, *, format, canvas_size, padding, padding_mode): left, right, top, bottom = _parse_pad_padding(padding) @@ -537,26 +539,26 @@ def reference_pad_bounding_box(bounding_box, *, format, spatial_size, padding, p [1, 0, left], [0, 1, top], ], - dtype="float64" if bounding_box.dtype == torch.float64 else "float32", + dtype="float64" if bounding_boxes.dtype == torch.float64 else "float32", ) - height = spatial_size[0] + top + bottom - width = spatial_size[1] + left + right + height = canvas_size[0] + top + bottom + width = canvas_size[1] + left + right - expected_bboxes = reference_affine_bounding_box_helper( - bounding_box, format=format, spatial_size=(height, width), affine_matrix=affine_matrix + expected_bboxes = reference_affine_bounding_boxes_helper( + bounding_boxes, format=format, canvas_size=(height, width), affine_matrix=affine_matrix ) return expected_bboxes, (height, width) -def reference_inputs_pad_bounding_box(): - for bounding_box_loader, padding in itertools.product( +def reference_inputs_pad_bounding_boxes(): + for bounding_boxes_loader, padding in itertools.product( make_bounding_box_loaders(extra_dims=((), (4,))), [1, (1,), (1, 2), (1, 2, 3, 4), [1], [1, 2], [1, 2, 3, 4]] ): yield ArgsKwargs( - bounding_box_loader, - format=bounding_box_loader.format, - spatial_size=bounding_box_loader.spatial_size, + bounding_boxes_loader, + format=bounding_boxes_loader.format, + canvas_size=bounding_boxes_loader.canvas_size, padding=padding, padding_mode="constant", ) @@ -589,10 +591,10 @@ def pad_xfail_jit_fill_condition(args_kwargs): ], ), KernelInfo( - F.pad_bounding_box, - sample_inputs_fn=sample_inputs_pad_bounding_box, - reference_fn=reference_pad_bounding_box, - reference_inputs_fn=reference_inputs_pad_bounding_box, + F.pad_bounding_boxes, + 
sample_inputs_fn=sample_inputs_pad_bounding_boxes, + reference_fn=reference_pad_bounding_boxes, + reference_inputs_fn=reference_inputs_pad_bounding_boxes, test_marks=[ xfail_jit_python_scalar_arg("padding"), ], @@ -620,7 +622,7 @@ def pad_xfail_jit_fill_condition(args_kwargs): def sample_inputs_perspective_image_tensor(): - for image_loader in make_image_loaders(sizes=["random"]): + for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE]): for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype): yield ArgsKwargs( image_loader, startpoints=None, endpoints=None, fill=fill, coefficients=_PERSPECTIVE_COEFFS[0] @@ -653,12 +655,12 @@ def reference_inputs_perspective_image_tensor(): ) -def sample_inputs_perspective_bounding_box(): - for bounding_box_loader in make_bounding_box_loaders(): +def sample_inputs_perspective_bounding_boxes(): + for bounding_boxes_loader in make_bounding_box_loaders(): yield ArgsKwargs( - bounding_box_loader, - format=bounding_box_loader.format, - spatial_size=bounding_box_loader.spatial_size, + bounding_boxes_loader, + format=bounding_boxes_loader.format, + canvas_size=bounding_boxes_loader.canvas_size, startpoints=None, endpoints=None, coefficients=_PERSPECTIVE_COEFFS[0], @@ -667,12 +669,12 @@ def sample_inputs_perspective_bounding_box(): format = datapoints.BoundingBoxFormat.XYXY loader = make_bounding_box_loader(format=format) yield ArgsKwargs( - loader, format=format, spatial_size=loader.spatial_size, startpoints=_STARTPOINTS, endpoints=_ENDPOINTS + loader, format=format, canvas_size=loader.canvas_size, startpoints=_STARTPOINTS, endpoints=_ENDPOINTS ) def sample_inputs_perspective_mask(): - for mask_loader in make_mask_loaders(sizes=["random"]): + for mask_loader in make_mask_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE]): yield ArgsKwargs(mask_loader, startpoints=None, endpoints=None, coefficients=_PERSPECTIVE_COEFFS[0]) yield ArgsKwargs(make_detection_mask_loader(), startpoints=_STARTPOINTS, endpoints=_ENDPOINTS) @@ -686,7 +688,7 @@ def reference_inputs_perspective_mask(): def sample_inputs_perspective_video(): - for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]): yield ArgsKwargs(video_loader, startpoints=None, endpoints=None, coefficients=_PERSPECTIVE_COEFFS[0]) yield ArgsKwargs(make_video_loader(), startpoints=_STARTPOINTS, endpoints=_ENDPOINTS) @@ -710,8 +712,8 @@ def sample_inputs_perspective_video(): test_marks=[xfail_jit_python_scalar_arg("fill")], ), KernelInfo( - F.perspective_bounding_box, - sample_inputs_fn=sample_inputs_perspective_bounding_box, + F.perspective_bounding_boxes, + sample_inputs_fn=sample_inputs_perspective_bounding_boxes, closeness_kwargs={ **scripted_vs_eager_float64_tolerances("cpu", atol=1e-6, rtol=1e-6), **scripted_vs_eager_float64_tolerances("cuda", atol=1e-6, rtol=1e-6), @@ -740,13 +742,13 @@ def sample_inputs_perspective_video(): ) -def _get_elastic_displacement(spatial_size): - return torch.rand(1, *spatial_size, 2) +def _get_elastic_displacement(canvas_size): + return torch.rand(1, *canvas_size, 2) def sample_inputs_elastic_image_tensor(): - for image_loader in make_image_loaders(sizes=["random"]): - displacement = _get_elastic_displacement(image_loader.spatial_size) + for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE]): + displacement = _get_elastic_displacement(image_loader.canvas_size) for fill in 
get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype): yield ArgsKwargs(image_loader, displacement=displacement, fill=fill) @@ -760,30 +762,30 @@ def reference_inputs_elastic_image_tensor(): F.InterpolationMode.BICUBIC, ], ): - displacement = _get_elastic_displacement(image_loader.spatial_size) + displacement = _get_elastic_displacement(image_loader.canvas_size) for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype): yield ArgsKwargs(image_loader, interpolation=interpolation, displacement=displacement, fill=fill) -def sample_inputs_elastic_bounding_box(): - for bounding_box_loader in make_bounding_box_loaders(): - displacement = _get_elastic_displacement(bounding_box_loader.spatial_size) +def sample_inputs_elastic_bounding_boxes(): + for bounding_boxes_loader in make_bounding_box_loaders(): + displacement = _get_elastic_displacement(bounding_boxes_loader.canvas_size) yield ArgsKwargs( - bounding_box_loader, - format=bounding_box_loader.format, - spatial_size=bounding_box_loader.spatial_size, + bounding_boxes_loader, + format=bounding_boxes_loader.format, + canvas_size=bounding_boxes_loader.canvas_size, displacement=displacement, ) def sample_inputs_elastic_mask(): - for mask_loader in make_mask_loaders(sizes=["random"]): + for mask_loader in make_mask_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE]): displacement = _get_elastic_displacement(mask_loader.shape[-2:]) yield ArgsKwargs(mask_loader, displacement=displacement) def sample_inputs_elastic_video(): - for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]): displacement = _get_elastic_displacement(video_loader.shape[-2:]) yield ArgsKwargs(video_loader, displacement=displacement) @@ -802,8 +804,8 @@ def sample_inputs_elastic_video(): test_marks=[xfail_jit_python_scalar_arg("fill")], ), KernelInfo( - F.elastic_bounding_box, - sample_inputs_fn=sample_inputs_elastic_bounding_box, + F.elastic_bounding_boxes, + sample_inputs_fn=sample_inputs_elastic_bounding_boxes, ), KernelInfo( F.elastic_mask, @@ -843,18 +845,18 @@ def reference_inputs_center_crop_image_tensor(): yield ArgsKwargs(image_loader, output_size=output_size) -def sample_inputs_center_crop_bounding_box(): - for bounding_box_loader, output_size in itertools.product(make_bounding_box_loaders(), _CENTER_CROP_OUTPUT_SIZES): +def sample_inputs_center_crop_bounding_boxes(): + for bounding_boxes_loader, output_size in itertools.product(make_bounding_box_loaders(), _CENTER_CROP_OUTPUT_SIZES): yield ArgsKwargs( - bounding_box_loader, - format=bounding_box_loader.format, - spatial_size=bounding_box_loader.spatial_size, + bounding_boxes_loader, + format=bounding_boxes_loader.format, + canvas_size=bounding_boxes_loader.canvas_size, output_size=output_size, ) def sample_inputs_center_crop_mask(): - for mask_loader in make_mask_loaders(sizes=["random"], num_categories=["random"], num_objects=["random"]): + for mask_loader in make_mask_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_categories=[10], num_objects=[5]): height, width = mask_loader.shape[-2:] yield ArgsKwargs(mask_loader, output_size=(height // 2, width // 2)) @@ -867,7 +869,7 @@ def reference_inputs_center_crop_mask(): def sample_inputs_center_crop_video(): - for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]): height, width = 
video_loader.shape[-2:] yield ArgsKwargs(video_loader, output_size=(height // 2, width // 2)) @@ -885,8 +887,8 @@ def sample_inputs_center_crop_video(): ], ), KernelInfo( - F.center_crop_bounding_box, - sample_inputs_fn=sample_inputs_center_crop_bounding_box, + F.center_crop_bounding_boxes, + sample_inputs_fn=sample_inputs_center_crop_bounding_boxes, test_marks=[ xfail_jit_python_scalar_arg("output_size"), ], @@ -947,7 +949,7 @@ def sample_inputs_gaussian_blur_video(): def sample_inputs_equalize_image_tensor(): - for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")): + for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")): yield ArgsKwargs(image_loader) @@ -973,7 +975,7 @@ def make_beta_distributed_image(shape, dtype, device, *, alpha, beta, memory_for image.mul_(torch.iinfo(dtype).max).round_() return image.to(dtype=dtype, device=device, memory_format=memory_format, copy=True) - spatial_size = (256, 256) + canvas_size = (256, 256) for dtype, color_space, fn in itertools.product( [torch.uint8], ["GRAY", "RGB"], @@ -1003,12 +1005,12 @@ def make_beta_distributed_image(shape, dtype, device, *, alpha, beta, memory_for ], ], ): - image_loader = ImageLoader(fn, shape=(get_num_channels(color_space), *spatial_size), dtype=dtype) + image_loader = ImageLoader(fn, shape=(get_num_channels(color_space), *canvas_size), dtype=dtype) yield ArgsKwargs(image_loader) def sample_inputs_equalize_video(): - for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]): yield ArgsKwargs(video_loader) @@ -1031,7 +1033,7 @@ def sample_inputs_equalize_video(): def sample_inputs_invert_image_tensor(): - for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")): + for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")): yield ArgsKwargs(image_loader) @@ -1041,7 +1043,7 @@ def reference_inputs_invert_image_tensor(): def sample_inputs_invert_video(): - for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]): yield ArgsKwargs(video_loader) @@ -1067,7 +1069,7 @@ def sample_inputs_invert_video(): def sample_inputs_posterize_image_tensor(): - for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")): + for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")): yield ArgsKwargs(image_loader, bits=_POSTERIZE_BITS[0]) @@ -1080,7 +1082,7 @@ def reference_inputs_posterize_image_tensor(): def sample_inputs_posterize_video(): - for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]): yield ArgsKwargs(video_loader, bits=_POSTERIZE_BITS[0]) @@ -1110,7 +1112,7 @@ def _get_solarize_thresholds(dtype): def sample_inputs_solarize_image_tensor(): - for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")): + for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")): yield ArgsKwargs(image_loader, threshold=next(_get_solarize_thresholds(image_loader.dtype))) @@ -1125,7 +1127,7 @@ def uint8_to_float32_threshold_adapter(other_args, kwargs): def 
sample_inputs_solarize_video(): - for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]): yield ArgsKwargs(video_loader, threshold=next(_get_solarize_thresholds(video_loader.dtype))) @@ -1149,7 +1151,7 @@ def sample_inputs_solarize_video(): def sample_inputs_autocontrast_image_tensor(): - for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")): + for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")): yield ArgsKwargs(image_loader) @@ -1159,7 +1161,7 @@ def reference_inputs_autocontrast_image_tensor(): def sample_inputs_autocontrast_video(): - for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]): yield ArgsKwargs(video_loader) @@ -1189,7 +1191,7 @@ def sample_inputs_autocontrast_video(): def sample_inputs_adjust_sharpness_image_tensor(): for image_loader in make_image_loaders( - sizes=["random", (2, 2)], + sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE, (2, 2)], color_spaces=("GRAY", "RGB"), ): yield ArgsKwargs(image_loader, sharpness_factor=_ADJUST_SHARPNESS_FACTORS[0]) @@ -1204,7 +1206,7 @@ def reference_inputs_adjust_sharpness_image_tensor(): def sample_inputs_adjust_sharpness_video(): - for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]): yield ArgsKwargs(video_loader, sharpness_factor=_ADJUST_SHARPNESS_FACTORS[0]) @@ -1228,7 +1230,7 @@ def sample_inputs_adjust_sharpness_video(): def sample_inputs_erase_image_tensor(): - for image_loader in make_image_loaders(sizes=["random"]): + for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE]): # FIXME: make the parameters more diverse h, w = 6, 7 v = torch.rand(image_loader.num_channels, h, w) @@ -1236,7 +1238,7 @@ def sample_inputs_erase_image_tensor(): def sample_inputs_erase_video(): - for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]): # FIXME: make the parameters more diverse h, w = 6, 7 v = torch.rand(video_loader.num_channels, h, w) @@ -1261,7 +1263,7 @@ def sample_inputs_erase_video(): def sample_inputs_adjust_brightness_image_tensor(): - for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")): + for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")): yield ArgsKwargs(image_loader, brightness_factor=_ADJUST_BRIGHTNESS_FACTORS[0]) @@ -1274,7 +1276,7 @@ def reference_inputs_adjust_brightness_image_tensor(): def sample_inputs_adjust_brightness_video(): - for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]): yield ArgsKwargs(video_loader, brightness_factor=_ADJUST_BRIGHTNESS_FACTORS[0]) @@ -1301,7 +1303,7 @@ def sample_inputs_adjust_brightness_video(): def sample_inputs_adjust_contrast_image_tensor(): - for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")): + for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")): yield ArgsKwargs(image_loader, 
contrast_factor=_ADJUST_CONTRAST_FACTORS[0]) @@ -1314,7 +1316,7 @@ def reference_inputs_adjust_contrast_image_tensor(): def sample_inputs_adjust_contrast_video(): - for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]): yield ArgsKwargs(video_loader, contrast_factor=_ADJUST_CONTRAST_FACTORS[0]) @@ -1353,7 +1355,7 @@ def sample_inputs_adjust_contrast_video(): def sample_inputs_adjust_gamma_image_tensor(): gamma, gain = _ADJUST_GAMMA_GAMMAS_GAINS[0] - for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")): + for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")): yield ArgsKwargs(image_loader, gamma=gamma, gain=gain) @@ -1367,7 +1369,7 @@ def reference_inputs_adjust_gamma_image_tensor(): def sample_inputs_adjust_gamma_video(): gamma, gain = _ADJUST_GAMMA_GAMMAS_GAINS[0] - for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]): yield ArgsKwargs(video_loader, gamma=gamma, gain=gain) @@ -1397,7 +1399,7 @@ def sample_inputs_adjust_gamma_video(): def sample_inputs_adjust_hue_image_tensor(): - for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")): + for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")): yield ArgsKwargs(image_loader, hue_factor=_ADJUST_HUE_FACTORS[0]) @@ -1410,7 +1412,7 @@ def reference_inputs_adjust_hue_image_tensor(): def sample_inputs_adjust_hue_video(): - for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]): yield ArgsKwargs(video_loader, hue_factor=_ADJUST_HUE_FACTORS[0]) @@ -1439,7 +1441,7 @@ def sample_inputs_adjust_hue_video(): def sample_inputs_adjust_saturation_image_tensor(): - for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")): + for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")): yield ArgsKwargs(image_loader, saturation_factor=_ADJUST_SATURATION_FACTORS[0]) @@ -1452,7 +1454,7 @@ def reference_inputs_adjust_saturation_image_tensor(): def sample_inputs_adjust_saturation_video(): - for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): + for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]): yield ArgsKwargs(video_loader, saturation_factor=_ADJUST_SATURATION_FACTORS[0]) @@ -1480,19 +1482,19 @@ def sample_inputs_adjust_saturation_video(): ) -def sample_inputs_clamp_bounding_box(): - for bounding_box_loader in make_bounding_box_loaders(): +def sample_inputs_clamp_bounding_boxes(): + for bounding_boxes_loader in make_bounding_box_loaders(): yield ArgsKwargs( - bounding_box_loader, - format=bounding_box_loader.format, - spatial_size=bounding_box_loader.spatial_size, + bounding_boxes_loader, + format=bounding_boxes_loader.format, + canvas_size=bounding_boxes_loader.canvas_size, ) KERNEL_INFOS.append( KernelInfo( - F.clamp_bounding_box, - sample_inputs_fn=sample_inputs_clamp_bounding_box, + F.clamp_bounding_boxes, + sample_inputs_fn=sample_inputs_clamp_bounding_boxes, logs_usage=True, ) ) @@ -1500,7 +1502,7 @@ def sample_inputs_clamp_bounding_box(): _FIVE_TEN_CROP_SIZES = [7, (6,), [5], 
(6, 5), [7, 6]] -def _get_five_ten_crop_spatial_size(size): +def _get_five_ten_crop_canvas_size(size): if isinstance(size, int): crop_height = crop_width = size elif len(size) == 1: @@ -1513,7 +1515,7 @@ def _get_five_ten_crop_spatial_size(size): def sample_inputs_five_crop_image_tensor(): for size in _FIVE_TEN_CROP_SIZES: for image_loader in make_image_loaders( - sizes=[_get_five_ten_crop_spatial_size(size)], + sizes=[_get_five_ten_crop_canvas_size(size)], color_spaces=["RGB"], dtypes=[torch.float32], ): @@ -1523,21 +1525,21 @@ def sample_inputs_five_crop_image_tensor(): def reference_inputs_five_crop_image_tensor(): for size in _FIVE_TEN_CROP_SIZES: for image_loader in make_image_loaders( - sizes=[_get_five_ten_crop_spatial_size(size)], extra_dims=[()], dtypes=[torch.uint8] + sizes=[_get_five_ten_crop_canvas_size(size)], extra_dims=[()], dtypes=[torch.uint8] ): yield ArgsKwargs(image_loader, size=size) def sample_inputs_five_crop_video(): size = _FIVE_TEN_CROP_SIZES[0] - for video_loader in make_video_loaders(sizes=[_get_five_ten_crop_spatial_size(size)]): + for video_loader in make_video_loaders(sizes=[_get_five_ten_crop_canvas_size(size)]): yield ArgsKwargs(video_loader, size=size) def sample_inputs_ten_crop_image_tensor(): for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]): for image_loader in make_image_loaders( - sizes=[_get_five_ten_crop_spatial_size(size)], + sizes=[_get_five_ten_crop_canvas_size(size)], color_spaces=["RGB"], dtypes=[torch.float32], ): @@ -1547,14 +1549,14 @@ def sample_inputs_ten_crop_image_tensor(): def reference_inputs_ten_crop_image_tensor(): for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]): for image_loader in make_image_loaders( - sizes=[_get_five_ten_crop_spatial_size(size)], extra_dims=[()], dtypes=[torch.uint8] + sizes=[_get_five_ten_crop_canvas_size(size)], extra_dims=[()], dtypes=[torch.uint8] ): yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip) def sample_inputs_ten_crop_video(): size = _FIVE_TEN_CROP_SIZES[0] - for video_loader in make_video_loaders(sizes=[_get_five_ten_crop_spatial_size(size)]): + for video_loader in make_video_loaders(sizes=[_get_five_ten_crop_canvas_size(size)]): yield ArgsKwargs(video_loader, size=size) @@ -1562,7 +1564,7 @@ def multi_crop_pil_reference_wrapper(pil_kernel): def wrapper(input_tensor, *other_args, **kwargs): output = pil_reference_wrapper(pil_kernel)(input_tensor, *other_args, **kwargs) return type(output)( - F.convert_dtype_image_tensor(F.to_image_tensor(output_pil), dtype=input_tensor.dtype) + F.to_dtype_image_tensor(F.to_image_tensor(output_pil), dtype=input_tensor.dtype, scale=True) for output_pil in output ) @@ -1612,7 +1614,7 @@ def wrapper(input_tensor, *other_args, **kwargs): def sample_inputs_normalize_image_tensor(): for image_loader, (mean, std) in itertools.product( - make_image_loaders(sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32]), + make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=["RGB"], dtypes=[torch.float32]), _NORMALIZE_MEANS_STDS, ): yield ArgsKwargs(image_loader, mean=mean, std=std) @@ -1637,7 +1639,7 @@ def reference_inputs_normalize_image_tensor(): def sample_inputs_normalize_video(): mean, std = _NORMALIZE_MEANS_STDS[0] for video_loader in make_video_loaders( - sizes=["random"], color_spaces=["RGB"], num_frames=["random"], dtypes=[torch.float32] + sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=["RGB"], num_frames=[3], dtypes=[torch.float32] ): yield 
ArgsKwargs(video_loader, mean=mean, std=std) @@ -1663,125 +1665,8 @@ def sample_inputs_normalize_video(): ) -def sample_inputs_convert_dtype_image_tensor(): - for input_dtype, output_dtype in itertools.product( - [torch.uint8, torch.int64, torch.float32, torch.float64], repeat=2 - ): - if input_dtype.is_floating_point and output_dtype == torch.int64: - # conversion cannot be performed safely - continue - - for image_loader in make_image_loaders(sizes=["random"], color_spaces=["RGB"], dtypes=[input_dtype]): - yield ArgsKwargs(image_loader, dtype=output_dtype) - - -def reference_convert_dtype_image_tensor(image, dtype=torch.float): - input_dtype = image.dtype - output_dtype = dtype - - if output_dtype == input_dtype: - return image - - def fn(value): - if input_dtype.is_floating_point: - if output_dtype.is_floating_point: - return value - else: - return int(decimal.Decimal(value) * torch.iinfo(output_dtype).max) - else: - input_max_value = torch.iinfo(input_dtype).max - - if output_dtype.is_floating_point: - return float(decimal.Decimal(value) / input_max_value) - else: - output_max_value = torch.iinfo(output_dtype).max - - if input_max_value > output_max_value: - factor = (input_max_value + 1) // (output_max_value + 1) - return value // factor - else: - factor = (output_max_value + 1) // (input_max_value + 1) - return value * factor - - return torch.tensor(tree_map(fn, image.tolist()), dtype=dtype) - - -def reference_inputs_convert_dtype_image_tensor(): - for input_dtype, output_dtype in itertools.product( - [ - torch.uint8, - torch.int16, - torch.int32, - torch.int64, - torch.float16, - torch.float32, - torch.float64, - torch.bfloat16, - ], - repeat=2, - ): - if (input_dtype == torch.float32 and output_dtype in {torch.int32, torch.int64}) or ( - input_dtype == torch.float64 and output_dtype == torch.int64 - ): - continue - - if input_dtype.is_floating_point: - data = [0.0, 0.5, 1.0] - else: - max_value = torch.iinfo(input_dtype).max - data = [0, max_value // 2, max_value] - image = torch.tensor(data, dtype=input_dtype) - - yield ArgsKwargs(image, dtype=output_dtype) - - -def sample_inputs_convert_dtype_video(): - for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): - yield ArgsKwargs(video_loader) - - -skip_dtype_consistency = TestMark( - ("TestKernels", "test_dtype_and_device_consistency"), - pytest.mark.skip(reason="`convert_dtype_*` kernels convert the dtype by design"), - condition=lambda args_kwargs: args_kwargs.args[0].dtype != args_kwargs.kwargs.get("dtype", torch.float32), -) - -KERNEL_INFOS.extend( - [ - KernelInfo( - F.convert_dtype_image_tensor, - sample_inputs_fn=sample_inputs_convert_dtype_image_tensor, - reference_fn=reference_convert_dtype_image_tensor, - reference_inputs_fn=reference_inputs_convert_dtype_image_tensor, - test_marks=[ - skip_dtype_consistency, - TestMark( - ("TestKernels", "test_against_reference"), - pytest.mark.xfail(reason="Conversion overflows"), - condition=lambda args_kwargs: ( - args_kwargs.args[0].dtype in {torch.float16, torch.bfloat16} - and not args_kwargs.kwargs["dtype"].is_floating_point - ) - or ( - args_kwargs.args[0].dtype in {torch.int32, torch.int64} - and args_kwargs.kwargs["dtype"] == torch.float16 - ), - ), - ], - ), - KernelInfo( - F.convert_dtype_video, - sample_inputs_fn=sample_inputs_convert_dtype_video, - test_marks=[ - skip_dtype_consistency, - ], - ), - ] -) - - def sample_inputs_uniform_temporal_subsample_video(): - for video_loader in make_video_loaders(sizes=["random"], num_frames=[4]): + for 
video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[4]): yield ArgsKwargs(video_loader, num_samples=2) @@ -1797,7 +1682,9 @@ def reference_uniform_temporal_subsample_video(x, num_samples): def reference_inputs_uniform_temporal_subsample_video(): - for video_loader in make_video_loaders(sizes=["random"], color_spaces=["RGB"], num_frames=[10]): + for video_loader in make_video_loaders( + sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=["RGB"], num_frames=[10] + ): for num_samples in range(1, video_loader.shape[-4] + 1): yield ArgsKwargs(video_loader, num_samples) diff --git a/torchvision/datapoints/__init__.py b/torchvision/datapoints/__init__.py index c9343048a2a..fb51f0497ea 100644 --- a/torchvision/datapoints/__init__.py +++ b/torchvision/datapoints/__init__.py @@ -1,6 +1,6 @@ from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS -from ._bounding_box import BoundingBox, BoundingBoxFormat +from ._bounding_box import BoundingBoxes, BoundingBoxFormat from ._datapoint import _FillType, _FillTypeJIT, _InputType, _InputTypeJIT from ._image import _ImageType, _ImageTypeJIT, _TensorImageType, _TensorImageTypeJIT, Image from ._mask import Mask diff --git a/torchvision/datapoints/_bounding_box.py b/torchvision/datapoints/_bounding_box.py index 11d42f171e4..780a950403c 100644 --- a/torchvision/datapoints/_bounding_box.py +++ b/torchvision/datapoints/_bounding_box.py @@ -24,13 +24,13 @@ class BoundingBoxFormat(Enum): CXCYWH = "CXCYWH" -class BoundingBox(Datapoint): +class BoundingBoxes(Datapoint): """[BETA] :class:`torch.Tensor` subclass for bounding boxes. Args: data: Any data that can be turned into a tensor with :func:`torch.as_tensor`. format (BoundingBoxFormat, str): Format of the bounding box. - spatial_size (two-tuple of ints): Height and width of the corresponding image or video. + canvas_size (two-tuple of ints): Height and width of the corresponding image or video. dtype (torch.dtype, optional): Desired data type of the bounding box. If omitted, will be inferred from ``data``. device (torch.device, optional): Desired device of the bounding box. 
If omitted and ``data`` is a @@ -40,49 +40,49 @@ class BoundingBox(Datapoint): """ format: BoundingBoxFormat - spatial_size: Tuple[int, int] + canvas_size: Tuple[int, int] @classmethod - def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, spatial_size: Tuple[int, int]) -> BoundingBox: - bounding_box = tensor.as_subclass(cls) - bounding_box.format = format - bounding_box.spatial_size = spatial_size - return bounding_box + def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, canvas_size: Tuple[int, int]) -> BoundingBoxes: + bounding_boxes = tensor.as_subclass(cls) + bounding_boxes.format = format + bounding_boxes.canvas_size = canvas_size + return bounding_boxes def __new__( cls, data: Any, *, format: Union[BoundingBoxFormat, str], - spatial_size: Tuple[int, int], + canvas_size: Tuple[int, int], dtype: Optional[torch.dtype] = None, device: Optional[Union[torch.device, str, int]] = None, requires_grad: Optional[bool] = None, - ) -> BoundingBox: + ) -> BoundingBoxes: tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) if isinstance(format, str): format = BoundingBoxFormat[format.upper()] - return cls._wrap(tensor, format=format, spatial_size=spatial_size) + return cls._wrap(tensor, format=format, canvas_size=canvas_size) @classmethod def wrap_like( cls, - other: BoundingBox, + other: BoundingBoxes, tensor: torch.Tensor, *, format: Optional[BoundingBoxFormat] = None, - spatial_size: Optional[Tuple[int, int]] = None, - ) -> BoundingBox: - """Wrap a :class:`torch.Tensor` as :class:`BoundingBox` from a reference. + canvas_size: Optional[Tuple[int, int]] = None, + ) -> BoundingBoxes: + """Wrap a :class:`torch.Tensor` as :class:`BoundingBoxes` from a reference. Args: - other (BoundingBox): Reference bounding box. - tensor (Tensor): Tensor to be wrapped as :class:`BoundingBox` + other (BoundingBoxes): Reference bounding box. + tensor (Tensor): Tensor to be wrapped as :class:`BoundingBoxes` format (BoundingBoxFormat, str, optional): Format of the bounding box. If omitted, it is taken from the reference. - spatial_size (two-tuple of ints, optional): Height and width of the corresponding image or video. If + canvas_size (two-tuple of ints, optional): Height and width of the corresponding image or video. If omitted, it is taken from the reference. 
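The hunk above renames the datapoint and its metadata: `BoundingBoxes` now carries `canvas_size` instead of `spatial_size`, both in `__new__` and in `wrap_like`. A minimal sketch of the renamed constructor and `wrap_like` as they appear in this patch (the box values and canvas shape below are made up for illustration):

```python
import torch
import torchvision

# The datapoints / transforms v2 API is still in beta at this point of the patch.
torchvision.disable_beta_transforms_warning()
from torchvision import datapoints

# Two boxes in XYXY format on a 100x100 canvas; the values are arbitrary.
boxes = datapoints.BoundingBoxes(
    [[10, 10, 50, 80], [0, 0, 20, 20]],
    format="XYXY",            # a string is upper-cased and mapped to BoundingBoxFormat.XYXY
    canvas_size=(100, 100),   # formerly `spatial_size`: (height, width) of the image/video
)
print(boxes.format, boxes.canvas_size)

# Kernels operate on plain tensors; `wrap_like` re-attaches the metadata of a
# reference box, optionally overriding it (e.g. after the canvas was resized):
plain = boxes.as_subclass(torch.Tensor)
halved = datapoints.BoundingBoxes.wrap_like(boxes, plain / 2, canvas_size=(50, 50))
```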
""" @@ -92,23 +92,23 @@ def wrap_like( return cls._wrap( tensor, format=format if format is not None else other.format, - spatial_size=spatial_size if spatial_size is not None else other.spatial_size, + canvas_size=canvas_size if canvas_size is not None else other.canvas_size, ) def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] - return self._make_repr(format=self.format, spatial_size=self.spatial_size) + return self._make_repr(format=self.format, canvas_size=self.canvas_size) - def horizontal_flip(self) -> BoundingBox: - output = self._F.horizontal_flip_bounding_box( - self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size + def horizontal_flip(self) -> BoundingBoxes: + output = self._F.horizontal_flip_bounding_boxes( + self.as_subclass(torch.Tensor), format=self.format, canvas_size=self.canvas_size ) - return BoundingBox.wrap_like(self, output) + return BoundingBoxes.wrap_like(self, output) - def vertical_flip(self) -> BoundingBox: - output = self._F.vertical_flip_bounding_box( - self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size + def vertical_flip(self) -> BoundingBoxes: + output = self._F.vertical_flip_bounding_boxes( + self.as_subclass(torch.Tensor), format=self.format, canvas_size=self.canvas_size ) - return BoundingBox.wrap_like(self, output) + return BoundingBoxes.wrap_like(self, output) def resize( # type: ignore[override] self, @@ -116,26 +116,26 @@ def resize( # type: ignore[override] interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, max_size: Optional[int] = None, antialias: Optional[Union[str, bool]] = "warn", - ) -> BoundingBox: - output, spatial_size = self._F.resize_bounding_box( + ) -> BoundingBoxes: + output, canvas_size = self._F.resize_bounding_boxes( self.as_subclass(torch.Tensor), - spatial_size=self.spatial_size, + canvas_size=self.canvas_size, size=size, max_size=max_size, ) - return BoundingBox.wrap_like(self, output, spatial_size=spatial_size) + return BoundingBoxes.wrap_like(self, output, canvas_size=canvas_size) - def crop(self, top: int, left: int, height: int, width: int) -> BoundingBox: - output, spatial_size = self._F.crop_bounding_box( + def crop(self, top: int, left: int, height: int, width: int) -> BoundingBoxes: + output, canvas_size = self._F.crop_bounding_boxes( self.as_subclass(torch.Tensor), self.format, top=top, left=left, height=height, width=width ) - return BoundingBox.wrap_like(self, output, spatial_size=spatial_size) + return BoundingBoxes.wrap_like(self, output, canvas_size=canvas_size) - def center_crop(self, output_size: List[int]) -> BoundingBox: - output, spatial_size = self._F.center_crop_bounding_box( - self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size, output_size=output_size + def center_crop(self, output_size: List[int]) -> BoundingBoxes: + output, canvas_size = self._F.center_crop_bounding_boxes( + self.as_subclass(torch.Tensor), format=self.format, canvas_size=self.canvas_size, output_size=output_size ) - return BoundingBox.wrap_like(self, output, spatial_size=spatial_size) + return BoundingBoxes.wrap_like(self, output, canvas_size=canvas_size) def resized_crop( self, @@ -146,26 +146,26 @@ def resized_crop( size: List[int], interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, antialias: Optional[Union[str, bool]] = "warn", - ) -> BoundingBox: - output, spatial_size = self._F.resized_crop_bounding_box( + ) -> BoundingBoxes: + output, canvas_size = 
self._F.resized_crop_bounding_boxes( self.as_subclass(torch.Tensor), self.format, top, left, height, width, size=size ) - return BoundingBox.wrap_like(self, output, spatial_size=spatial_size) + return BoundingBoxes.wrap_like(self, output, canvas_size=canvas_size) def pad( self, padding: Union[int, Sequence[int]], fill: Optional[Union[int, float, List[float]]] = None, padding_mode: str = "constant", - ) -> BoundingBox: - output, spatial_size = self._F.pad_bounding_box( + ) -> BoundingBoxes: + output, canvas_size = self._F.pad_bounding_boxes( self.as_subclass(torch.Tensor), format=self.format, - spatial_size=self.spatial_size, + canvas_size=self.canvas_size, padding=padding, padding_mode=padding_mode, ) - return BoundingBox.wrap_like(self, output, spatial_size=spatial_size) + return BoundingBoxes.wrap_like(self, output, canvas_size=canvas_size) def rotate( self, @@ -174,16 +174,16 @@ def rotate( expand: bool = False, center: Optional[List[float]] = None, fill: _FillTypeJIT = None, - ) -> BoundingBox: - output, spatial_size = self._F.rotate_bounding_box( + ) -> BoundingBoxes: + output, canvas_size = self._F.rotate_bounding_boxes( self.as_subclass(torch.Tensor), format=self.format, - spatial_size=self.spatial_size, + canvas_size=self.canvas_size, angle=angle, expand=expand, center=center, ) - return BoundingBox.wrap_like(self, output, spatial_size=spatial_size) + return BoundingBoxes.wrap_like(self, output, canvas_size=canvas_size) def affine( self, @@ -194,18 +194,18 @@ def affine( interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, fill: _FillTypeJIT = None, center: Optional[List[float]] = None, - ) -> BoundingBox: - output = self._F.affine_bounding_box( + ) -> BoundingBoxes: + output = self._F.affine_bounding_boxes( self.as_subclass(torch.Tensor), self.format, - self.spatial_size, + self.canvas_size, angle, translate=translate, scale=scale, shear=shear, center=center, ) - return BoundingBox.wrap_like(self, output) + return BoundingBoxes.wrap_like(self, output) def perspective( self, @@ -214,24 +214,24 @@ def perspective( interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, fill: _FillTypeJIT = None, coefficients: Optional[List[float]] = None, - ) -> BoundingBox: - output = self._F.perspective_bounding_box( + ) -> BoundingBoxes: + output = self._F.perspective_bounding_boxes( self.as_subclass(torch.Tensor), format=self.format, - spatial_size=self.spatial_size, + canvas_size=self.canvas_size, startpoints=startpoints, endpoints=endpoints, coefficients=coefficients, ) - return BoundingBox.wrap_like(self, output) + return BoundingBoxes.wrap_like(self, output) def elastic( self, displacement: torch.Tensor, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, fill: _FillTypeJIT = None, - ) -> BoundingBox: - output = self._F.elastic_bounding_box( - self.as_subclass(torch.Tensor), self.format, self.spatial_size, displacement=displacement + ) -> BoundingBoxes: + output = self._F.elastic_bounding_boxes( + self.as_subclass(torch.Tensor), self.format, self.canvas_size, displacement=displacement ) - return BoundingBox.wrap_like(self, output) + return BoundingBoxes.wrap_like(self, output) diff --git a/torchvision/datapoints/_datapoint.py b/torchvision/datapoints/_datapoint.py index 0dabec58f25..2059a3a18a0 100644 --- a/torchvision/datapoints/_datapoint.py +++ b/torchvision/datapoints/_datapoint.py @@ -138,8 +138,8 @@ def __deepcopy__(self: D, memo: Dict[int, Any]) -> D: # *not* happen for `deepcopy(Tensor)`. 
A side-effect from detaching is that the `Tensor.requires_grad` # attribute is cleared, so we need to refill it before we return. # Note: We don't explicitly handle deep-copying of the metadata here. The only metadata we currently have is - # `BoundingBox.format` and `BoundingBox.spatial_size`, which are immutable and thus implicitly deep-copied by - # `BoundingBox.clone()`. + # `BoundingBoxes.format` and `BoundingBoxes.canvas_size`, which are immutable and thus implicitly deep-copied by + # `BoundingBoxes.clone()`. return self.detach().clone().requires_grad_(self.requires_grad) # type: ignore[return-value] def horizontal_flip(self) -> Datapoint: diff --git a/torchvision/datapoints/_dataset_wrapper.py b/torchvision/datapoints/_dataset_wrapper.py index d88bc81e62b..f1e7857264a 100644 --- a/torchvision/datapoints/_dataset_wrapper.py +++ b/torchvision/datapoints/_dataset_wrapper.py @@ -8,7 +8,6 @@ from collections import defaultdict import torch -from torch.utils.data import Dataset from torchvision import datapoints, datasets from torchvision.transforms.v2 import functional as F @@ -44,7 +43,7 @@ def wrap_dataset_for_transforms_v2(dataset, target_keys=None): the target and wrap the data in the corresponding ``torchvision.datapoints``. The original keys are preserved. If ``target_keys`` is ommitted, returns only the values for the ``"boxes"`` and ``"labels"``. * :class:`~torchvision.datasets.CelebA`: The target for ``target_type="bbox"`` is converted to the ``XYXY`` - coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBox` datapoint. + coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBoxes` datapoint. * :class:`~torchvision.datasets.Kitti`: Instead returning the target as list of dicts, the wrapper returns a dict of lists. In addition, the key-value-pairs ``"boxes"`` and ``"labels"`` are added and wrap the data in the corresponding ``torchvision.datapoints``. The original keys are preserved. If ``target_keys`` is @@ -56,7 +55,7 @@ def wrap_dataset_for_transforms_v2(dataset, target_keys=None): a dictionary with the key-value-pairs ``"masks"`` (as :class:`~torchvision.datapoints.Mask` datapoint) and ``"labels"``. * :class:`~torchvision.datasets.WIDERFace`: The value for key ``"bbox"`` in the target is converted to ``XYXY`` - coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBox` datapoint. + coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBoxes` datapoint. Image classification datasets @@ -98,7 +97,16 @@ def wrap_dataset_for_transforms_v2(dataset, target_keys=None): f"but got {target_keys}" ) - return VisionDatasetDatapointWrapper(dataset, target_keys) + # Imagine we have isinstance(dataset, datasets.ImageNet). This will create a new class with the name + # "WrappedImageNet" at runtime that doubly inherits from VisionDatasetDatapointWrapper (see below) as well as the + # original ImageNet class. This allows the user to do regular isinstance(wrapped_dataset, datasets.ImageNet) checks, + # while we can still inject everything that we need. + wrapped_dataset_cls = type(f"Wrapped{type(dataset).__name__}", (VisionDatasetDatapointWrapper, type(dataset)), {}) + # Since VisionDatasetDatapointWrapper comes before ImageNet in the MRO, calling the class hits + # VisionDatasetDatapointWrapper.__init__ first. Since we are never doing super().__init__(...), the constructor of + # ImageNet is never hit. 
That is by design, since we don't want to create the dataset instance again, but rather + # have the existing instance as attribute on the new object. + return wrapped_dataset_cls(dataset, target_keys) class WrapperFactories(dict): @@ -117,7 +125,7 @@ def decorator(wrapper_factory): WRAPPER_FACTORIES = WrapperFactories() -class VisionDatasetDatapointWrapper(Dataset): +class VisionDatasetDatapointWrapper: def __init__(self, dataset, target_keys): dataset_cls = type(dataset) @@ -333,13 +341,13 @@ def coco_dectection_wrapper_factory(dataset, target_keys): default={"image_id", "boxes", "labels"}, ) - def segmentation_to_mask(segmentation, *, spatial_size): + def segmentation_to_mask(segmentation, *, canvas_size): from pycocotools import mask segmentation = ( - mask.frPyObjects(segmentation, *spatial_size) + mask.frPyObjects(segmentation, *canvas_size) if isinstance(segmentation, dict) - else mask.merge(mask.frPyObjects(segmentation, *spatial_size)) + else mask.merge(mask.frPyObjects(segmentation, *canvas_size)) ) return torch.from_numpy(mask.decode(segmentation)) @@ -351,7 +359,7 @@ def wrapper(idx, sample): if not target: return image, dict(image_id=image_id) - spatial_size = tuple(F.get_spatial_size(image)) + canvas_size = tuple(F.get_size(image)) batched_target = list_of_dicts_to_dict_of_lists(target) target = {} @@ -360,11 +368,11 @@ def wrapper(idx, sample): target["image_id"] = image_id if "boxes" in target_keys: - target["boxes"] = F.convert_format_bounding_box( - datapoints.BoundingBox( + target["boxes"] = F.convert_format_bounding_boxes( + datapoints.BoundingBoxes( batched_target["bbox"], format=datapoints.BoundingBoxFormat.XYWH, - spatial_size=spatial_size, + canvas_size=canvas_size, ), new_format=datapoints.BoundingBoxFormat.XYXY, ) @@ -373,7 +381,7 @@ def wrapper(idx, sample): target["masks"] = datapoints.Mask( torch.stack( [ - segmentation_to_mask(segmentation, spatial_size=spatial_size) + segmentation_to_mask(segmentation, canvas_size=canvas_size) for segmentation in batched_target["segmentation"] ] ), @@ -442,13 +450,13 @@ def wrapper(idx, sample): target = {} if "boxes" in target_keys: - target["boxes"] = datapoints.BoundingBox( + target["boxes"] = datapoints.BoundingBoxes( [ [int(bndbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")] for bndbox in batched_instances["bndbox"] ], format=datapoints.BoundingBoxFormat.XYXY, - spatial_size=(image.height, image.width), + canvas_size=(image.height, image.width), ) if "labels" in target_keys: @@ -481,11 +489,11 @@ def wrapper(idx, sample): target, target_types=dataset.target_type, type_wrappers={ - "bbox": lambda item: F.convert_format_bounding_box( - datapoints.BoundingBox( + "bbox": lambda item: F.convert_format_bounding_boxes( + datapoints.BoundingBoxes( item, format=datapoints.BoundingBoxFormat.XYWH, - spatial_size=(image.height, image.width), + canvas_size=(image.height, image.width), ), new_format=datapoints.BoundingBoxFormat.XYXY, ), @@ -532,10 +540,10 @@ def wrapper(idx, sample): target = {} if "boxes" in target_keys: - target["boxes"] = datapoints.BoundingBox( + target["boxes"] = datapoints.BoundingBoxes( batched_target["bbox"], format=datapoints.BoundingBoxFormat.XYXY, - spatial_size=(image.height, image.width), + canvas_size=(image.height, image.width), ) if "labels" in target_keys: @@ -628,9 +636,9 @@ def wrapper(idx, sample): target = {key: target[key] for key in target_keys} if "bbox" in target_keys: - target["bbox"] = F.convert_format_bounding_box( - datapoints.BoundingBox( - target["bbox"], 
format=datapoints.BoundingBoxFormat.XYWH, spatial_size=(image.height, image.width) + target["bbox"] = F.convert_format_bounding_boxes( + datapoints.BoundingBoxes( + target["bbox"], format=datapoints.BoundingBoxFormat.XYWH, canvas_size=(image.height, image.width) ), new_format=datapoints.BoundingBoxFormat.XYXY, ) diff --git a/torchvision/datapoints/_image.py b/torchvision/datapoints/_image.py index e47a6c10fc3..2ebf4954d02 100644 --- a/torchvision/datapoints/_image.py +++ b/torchvision/datapoints/_image.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, List, Optional, Tuple, Union +from typing import Any, List, Optional, Union import PIL.Image import torch @@ -56,14 +56,6 @@ def wrap_like(cls, other: Image, tensor: torch.Tensor) -> Image: def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] return self._make_repr() - @property - def spatial_size(self) -> Tuple[int, int]: - return tuple(self.shape[-2:]) # type: ignore[return-value] - - @property - def num_channels(self) -> int: - return self.shape[-3] - def horizontal_flip(self) -> Image: output = self._F.horizontal_flip_image_tensor(self.as_subclass(torch.Tensor)) return Image.wrap_like(self, output) diff --git a/torchvision/datapoints/_mask.py b/torchvision/datapoints/_mask.py index 0135d793d32..bc50b30583c 100644 --- a/torchvision/datapoints/_mask.py +++ b/torchvision/datapoints/_mask.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, List, Optional, Tuple, Union +from typing import Any, List, Optional, Union import PIL.Image import torch @@ -51,10 +51,6 @@ def wrap_like( ) -> Mask: return cls._wrap(tensor) - @property - def spatial_size(self) -> Tuple[int, int]: - return tuple(self.shape[-2:]) # type: ignore[return-value] - def horizontal_flip(self) -> Mask: output = self._F.horizontal_flip_mask(self.as_subclass(torch.Tensor)) return Mask.wrap_like(self, output) diff --git a/torchvision/datapoints/_video.py b/torchvision/datapoints/_video.py index a6fbe2bd473..d527a68a4d1 100644 --- a/torchvision/datapoints/_video.py +++ b/torchvision/datapoints/_video.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, List, Optional, Tuple, Union +from typing import Any, List, Optional, Union import torch from torchvision.transforms.functional import InterpolationMode @@ -46,18 +46,6 @@ def wrap_like(cls, other: Video, tensor: torch.Tensor) -> Video: def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] return self._make_repr() - @property - def spatial_size(self) -> Tuple[int, int]: - return tuple(self.shape[-2:]) # type: ignore[return-value] - - @property - def num_channels(self) -> int: - return self.shape[-3] - - @property - def num_frames(self) -> int: - return self.shape[-4] - def horizontal_flip(self) -> Video: output = self._F.horizontal_flip_video(self.as_subclass(torch.Tensor)) return Video.wrap_like(self, output) diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index e244207a8ed..0999bf7ba6b 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -1,3 +1,4 @@ +import fnmatch import importlib import inspect import sys @@ -6,7 +7,7 @@ from functools import partial from inspect import signature from types import ModuleType -from typing import Any, Callable, Dict, List, Mapping, Optional, Type, TypeVar, Union +from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Set, Type, TypeVar, Union from torch import nn @@ -203,19 +204,43 @@ def 
wrapper(fn: Callable[..., M]) -> Callable[..., M]: return wrapper -def list_models(module: Optional[ModuleType] = None) -> List[str]: +def list_models( + module: Optional[ModuleType] = None, + include: Union[Iterable[str], str, None] = None, + exclude: Union[Iterable[str], str, None] = None, +) -> List[str]: """ Returns a list with the names of registered models. Args: module (ModuleType, optional): The module from which we want to extract the available models. + include (str or Iterable[str], optional): Filter(s) for including the models from the set of all models. + Filters are passed to `fnmatch `__ to match Unix shell-style + wildcards. In case of many filters, the results is the union of individual filters. + exclude (str or Iterable[str], optional): Filter(s) applied after include_filters to remove models. + Filter are passed to `fnmatch `__ to match Unix shell-style + wildcards. In case of many filters, the results is removal of all the models that match any individual filter. Returns: models (list): A list with the names of available models. """ - models = [ + all_models = { k for k, v in BUILTIN_MODELS.items() if module is None or v.__module__.rsplit(".", 1)[0] == module.__name__ - ] + } + if include: + models: Set[str] = set() + if isinstance(include, str): + include = [include] + for include_filter in include: + models = models | set(fnmatch.filter(all_models, include_filter)) + else: + models = all_models + + if exclude: + if isinstance(exclude, str): + exclude = [exclude] + for exclude_filter in exclude: + models = models - set(fnmatch.filter(all_models, exclude_filter)) return sorted(models) diff --git a/torchvision/models/detection/_utils.py b/torchvision/models/detection/_utils.py index a25bdc1d42c..559db858ac3 100644 --- a/torchvision/models/detection/_utils.py +++ b/torchvision/models/detection/_utils.py @@ -408,17 +408,9 @@ def set_low_quality_matches_(self, matches: Tensor, all_matches: Tensor, match_q # Find the highest quality match available, even if it is low, including ties gt_pred_pairs_of_highest_quality = torch.where(match_quality_matrix == highest_quality_foreach_gt[:, None]) # Example gt_pred_pairs_of_highest_quality: - # tensor([[ 0, 39796], - # [ 1, 32055], - # [ 1, 32070], - # [ 2, 39190], - # [ 2, 40255], - # [ 3, 40390], - # [ 3, 41455], - # [ 4, 45470], - # [ 5, 45325], - # [ 5, 46390]]) - # Each row is a (gt index, prediction index) + # (tensor([0, 1, 1, 2, 2, 3, 3, 4, 5, 5]), + # tensor([39796, 32055, 32070, 39190, 40255, 40390, 41455, 45470, 45325, 46390])) + # Each element in the first tensor is a gt index, and each element in second tensor is a prediction index # Note how gt items 1, 2, 3, and 5 each have two ties pred_inds_to_update = gt_pred_pairs_of_highest_quality[1] diff --git a/torchvision/prototype/datasets/README.md b/torchvision/prototype/datasets/README.md new file mode 100644 index 00000000000..79b426caaf3 --- /dev/null +++ b/torchvision/prototype/datasets/README.md @@ -0,0 +1,7 @@ +# Status of prototype datasets + +These prototype datasets are based on [torchdata](https://github.com/pytorch/data)'s datapipes. Torchdata +development [is +paused](https://github.com/pytorch/data/#torchdata-see-note-below-on-current-status) +as of July 2023, so we are not actively maintaining this module. There is no +estimated date for a stable release of these datasets. 
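The `list_models` change earlier in this hunk adds fnmatch-based `include`/`exclude` filtering on top of the optional `module` argument: `include` keeps the union of all matching names, `exclude` then removes anything matching its patterns, and the result is returned sorted. A usage sketch (the wildcard patterns are illustrative; the exact names returned depend on the installed torchvision):

```python
from torchvision import models

# All registered model builder names.
all_names = models.list_models()

# Keep every name matching any `include` pattern (union of the filters) ...
resnets = models.list_models(include="*resnet*")

# ... then drop names matching any `exclude` pattern.
small_resnets = models.list_models(include="*resnet*", exclude=["*resnet101*", "*resnet152*"])

# Filters compose with the existing `module` argument.
detection_names = models.list_models(module=models.detection, exclude="*retinanet*")
print(sorted(set(resnets) - set(small_resnets)))
```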
diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py index f3882361638..631de46b2b6 100644 --- a/torchvision/prototype/datasets/_builtin/caltech.py +++ b/torchvision/prototype/datasets/_builtin/caltech.py @@ -6,7 +6,7 @@ import torch from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper -from torchvision.datapoints import BoundingBox +from torchvision.datapoints import BoundingBoxes from torchvision.prototype.datapoints import Label from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( @@ -112,7 +112,7 @@ def _prepare_sample( image_path=image_path, image=image, ann_path=ann_path, - bounding_box=BoundingBox( + bounding_boxes=BoundingBoxes( ann["box_coord"].astype(np.int64).squeeze()[[2, 0, 3, 1]], format="xyxy", spatial_size=image.spatial_size, diff --git a/torchvision/prototype/datasets/_builtin/celeba.py b/torchvision/prototype/datasets/_builtin/celeba.py index 2c819468778..9112a80357c 100644 --- a/torchvision/prototype/datasets/_builtin/celeba.py +++ b/torchvision/prototype/datasets/_builtin/celeba.py @@ -4,7 +4,7 @@ import torch from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper, Zipper -from torchvision.datapoints import BoundingBox +from torchvision.datapoints import BoundingBoxes from torchvision.prototype.datapoints import Label from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( @@ -137,15 +137,15 @@ def _prepare_sample( path, buffer = image_data image = EncodedImage.from_file(buffer) - (_, identity), (_, attributes), (_, bounding_box), (_, landmarks) = ann_data + (_, identity), (_, attributes), (_, bounding_boxes), (_, landmarks) = ann_data return dict( path=path, image=image, identity=Label(int(identity["identity"])), attributes={attr: value == "1" for attr, value in attributes.items()}, - bounding_box=BoundingBox( - [int(bounding_box[key]) for key in ("x_1", "y_1", "width", "height")], + bounding_boxes=BoundingBoxes( + [int(bounding_boxes[key]) for key in ("x_1", "y_1", "width", "height")], format="xywh", spatial_size=image.spatial_size, ), diff --git a/torchvision/prototype/datasets/_builtin/coco.py b/torchvision/prototype/datasets/_builtin/coco.py index 6616b4e3491..abf19acec0d 100644 --- a/torchvision/prototype/datasets/_builtin/coco.py +++ b/torchvision/prototype/datasets/_builtin/coco.py @@ -14,7 +14,7 @@ Mapper, UnBatcher, ) -from torchvision.datapoints import BoundingBox, Mask +from torchvision.datapoints import BoundingBoxes, Mask from torchvision.prototype.datapoints import Label from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( @@ -126,7 +126,7 @@ def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[st ), areas=torch.as_tensor([ann["area"] for ann in anns]), crowds=torch.as_tensor([ann["iscrowd"] for ann in anns], dtype=torch.bool), - bounding_boxes=BoundingBox( + bounding_boxes=BoundingBoxes( [ann["bbox"] for ann in anns], format="xywh", spatial_size=spatial_size, diff --git a/torchvision/prototype/datasets/_builtin/cub200.py b/torchvision/prototype/datasets/_builtin/cub200.py index bc41ba028c5..b301c6ba030 100644 --- a/torchvision/prototype/datasets/_builtin/cub200.py +++ 
b/torchvision/prototype/datasets/_builtin/cub200.py @@ -15,7 +15,7 @@ Mapper, ) from torchdata.datapipes.map import IterToMapConverter -from torchvision.datapoints import BoundingBox +from torchvision.datapoints import BoundingBoxes from torchvision.prototype.datapoints import Label from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( @@ -134,11 +134,11 @@ def _2011_segmentation_key(self, data: Tuple[str, Any]) -> str: def _2011_prepare_ann( self, data: Tuple[str, Tuple[List[str], Tuple[str, BinaryIO]]], spatial_size: Tuple[int, int] ) -> Dict[str, Any]: - _, (bounding_box_data, segmentation_data) = data + _, (bounding_boxes_data, segmentation_data) = data segmentation_path, segmentation_buffer = segmentation_data return dict( - bounding_box=BoundingBox( - [float(part) for part in bounding_box_data[1:]], format="xywh", spatial_size=spatial_size + bounding_boxes=BoundingBoxes( + [float(part) for part in bounding_boxes_data[1:]], format="xywh", spatial_size=spatial_size ), segmentation_path=segmentation_path, segmentation=EncodedImage.from_file(segmentation_buffer), @@ -158,7 +158,7 @@ def _2010_prepare_ann( content = read_mat(buffer) return dict( ann_path=path, - bounding_box=BoundingBox( + bounding_boxes=BoundingBoxes( [int(content["bbox"][coord]) for coord in ("left", "bottom", "right", "top")], format="xyxy", spatial_size=spatial_size, diff --git a/torchvision/prototype/datasets/_builtin/gtsrb.py b/torchvision/prototype/datasets/_builtin/gtsrb.py index 85116ca3860..34651fcfce3 100644 --- a/torchvision/prototype/datasets/_builtin/gtsrb.py +++ b/torchvision/prototype/datasets/_builtin/gtsrb.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union from torchdata.datapipes.iter import CSVDictParser, Demultiplexer, Filter, IterDataPipe, Mapper, Zipper -from torchvision.datapoints import BoundingBox +from torchvision.datapoints import BoundingBoxes from torchvision.prototype.datapoints import Label from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( @@ -76,7 +76,7 @@ def _prepare_sample(self, data: Tuple[Tuple[str, Any], Dict[str, Any]]) -> Dict[ (path, buffer), csv_info = data label = int(csv_info["ClassId"]) - bounding_box = BoundingBox( + bounding_boxes = BoundingBoxes( [int(csv_info[k]) for k in ("Roi.X1", "Roi.Y1", "Roi.X2", "Roi.Y2")], format="xyxy", spatial_size=(int(csv_info["Height"]), int(csv_info["Width"])), @@ -86,7 +86,7 @@ def _prepare_sample(self, data: Tuple[Tuple[str, Any], Dict[str, Any]]) -> Dict[ "path": path, "image": EncodedImage.from_file(buffer), "label": Label(label, categories=self._categories), - "bounding_box": bounding_box, + "bounding_boxes": bounding_boxes, } def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: diff --git a/torchvision/prototype/datasets/_builtin/stanford_cars.py b/torchvision/prototype/datasets/_builtin/stanford_cars.py index a76b2dba270..aefbbede2e3 100644 --- a/torchvision/prototype/datasets/_builtin/stanford_cars.py +++ b/torchvision/prototype/datasets/_builtin/stanford_cars.py @@ -2,7 +2,7 @@ from typing import Any, BinaryIO, Dict, Iterator, List, Tuple, Union from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper, Zipper -from torchvision.datapoints import BoundingBox +from torchvision.datapoints import BoundingBoxes from torchvision.prototype.datapoints 
import Label from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( @@ -90,7 +90,7 @@ def _prepare_sample(self, data: Tuple[Tuple[str, BinaryIO], Tuple[int, int, int, path=path, image=image, label=Label(target[4] - 1, categories=self._categories), - bounding_box=BoundingBox(target[:4], format="xyxy", spatial_size=image.spatial_size), + bounding_boxes=BoundingBoxes(target[:4], format="xyxy", spatial_size=image.spatial_size), ) def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: diff --git a/torchvision/prototype/datasets/_builtin/voc.py b/torchvision/prototype/datasets/_builtin/voc.py index a13cfb764e4..53dfbd185bc 100644 --- a/torchvision/prototype/datasets/_builtin/voc.py +++ b/torchvision/prototype/datasets/_builtin/voc.py @@ -5,7 +5,7 @@ from xml.etree import ElementTree from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper -from torchvision.datapoints import BoundingBox +from torchvision.datapoints import BoundingBoxes from torchvision.datasets import VOCDetection from torchvision.prototype.datapoints import Label from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource @@ -103,7 +103,7 @@ def _prepare_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]: anns = self._parse_detection_ann(buffer) instances = anns["object"] return dict( - bounding_boxes=BoundingBox( + bounding_boxes=BoundingBoxes( [ [int(instance["bndbox"][part]) for part in ("xmin", "ymin", "xmax", "ymax")] for instance in instances diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index 4f8fdef484c..e3a18599806 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -1,6 +1,6 @@ from ._presets import StereoMatching # usort: skip -from ._augment import RandomCutmix, RandomMixup, SimpleCopyPaste +from ._augment import RandomCutMix, RandomMixUp, SimpleCopyPaste from ._geometry import FixedSizeCrop from ._misc import PermuteDimensions, TransposeDimensions from ._type_conversion import LabelToOneHot diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index d04baf739d1..4da6cfcf9cd 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -11,10 +11,10 @@ from torchvision.transforms.v2._transform import _RandomApplyTransform from torchvision.transforms.v2.functional._geometry import _check_interpolation -from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_spatial_size +from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_size -class _BaseMixupCutmix(_RandomApplyTransform): +class _BaseMixUpCutMix(_RandomApplyTransform): def __init__(self, alpha: float, p: float = 0.5) -> None: super().__init__(p=p) self.alpha = alpha @@ -26,7 +26,7 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None: and has_any(flat_inputs, proto_datapoints.OneHotLabel) ): raise TypeError(f"{type(self).__name__}() is only defined for tensor images/videos and one-hot labels.") - if has_any(flat_inputs, PIL.Image.Image, datapoints.BoundingBox, datapoints.Mask, proto_datapoints.Label): + if has_any(flat_inputs, PIL.Image.Image, datapoints.BoundingBoxes, datapoints.Mask, proto_datapoints.Label): raise TypeError( f"{type(self).__name__}() does not support 
PIL images, bounding boxes, masks and plain labels." ) @@ -38,7 +38,7 @@ def _mixup_onehotlabel(self, inpt: proto_datapoints.OneHotLabel, lam: float) -> return proto_datapoints.OneHotLabel.wrap_like(inpt, output) -class RandomMixup(_BaseMixupCutmix): +class RandomMixUp(_BaseMixUpCutMix): def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(lam=float(self._dist.sample(()))) # type: ignore[arg-type] @@ -60,11 +60,11 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return inpt -class RandomCutmix(_BaseMixupCutmix): +class RandomCutMix(_BaseMixUpCutMix): def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: lam = float(self._dist.sample(())) # type: ignore[arg-type] - H, W = query_spatial_size(flat_inputs) + H, W = query_size(flat_inputs) r_x = torch.randint(W, ()) r_y = torch.randint(H, ()) @@ -175,7 +175,7 @@ def _copy_paste( # There is a similar +1 in other reference implementations: # https://github.com/pytorch/vision/blob/b6feccbc4387766b76a3e22b13815dbbbfa87c0f/torchvision/models/detection/roi_heads.py#L418-L422 xyxy_boxes[:, 2:] += 1 - boxes = F.convert_format_bounding_box( + boxes = F.convert_format_bounding_boxes( xyxy_boxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=bbox_format, inplace=True ) out_target["boxes"] = torch.cat([boxes, paste_boxes]) @@ -184,7 +184,7 @@ def _copy_paste( out_target["labels"] = torch.cat([labels, paste_labels]) # Check for degenerated boxes and remove them - boxes = F.convert_format_bounding_box( + boxes = F.convert_format_bounding_boxes( out_target["boxes"], old_format=bbox_format, new_format=datapoints.BoundingBoxFormat.XYXY ) degenerate_boxes = boxes[:, 2:] <= boxes[:, :2] @@ -201,14 +201,14 @@ def _extract_image_targets( self, flat_sample: List[Any] ) -> Tuple[List[datapoints._TensorImageType], List[Dict[str, Any]]]: # fetch all images, bboxes, masks and labels from unstructured input - # with List[image], List[BoundingBox], List[Mask], List[Label] + # with List[image], List[BoundingBoxes], List[Mask], List[Label] images, bboxes, masks, labels = [], [], [], [] for obj in flat_sample: if isinstance(obj, datapoints.Image) or is_simple_tensor(obj): images.append(obj) elif isinstance(obj, PIL.Image.Image): images.append(F.to_image_tensor(obj)) - elif isinstance(obj, datapoints.BoundingBox): + elif isinstance(obj, datapoints.BoundingBoxes): bboxes.append(obj) elif isinstance(obj, datapoints.Mask): masks.append(obj) @@ -218,7 +218,7 @@ def _extract_image_targets( if not (len(images) == len(bboxes) == len(masks) == len(labels)): raise TypeError( f"{type(self).__name__}() requires input sample to contain equal sized list of Images, " - "BoundingBoxes, Masks and Labels or OneHotLabels." + "BoundingBoxeses, Masks and Labels or OneHotLabels." 
) targets = [] @@ -244,8 +244,8 @@ def _insert_outputs( elif is_simple_tensor(obj): flat_sample[i] = output_images[c0] c0 += 1 - elif isinstance(obj, datapoints.BoundingBox): - flat_sample[i] = datapoints.BoundingBox.wrap_like(obj, output_targets[c1]["boxes"]) + elif isinstance(obj, datapoints.BoundingBoxes): + flat_sample[i] = datapoints.BoundingBoxes.wrap_like(obj, output_targets[c1]["boxes"]) c1 += 1 elif isinstance(obj, datapoints.Mask): flat_sample[i] = datapoints.Mask.wrap_like(obj, output_targets[c2]["masks"]) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 8d5cc24d25a..a4023ca2108 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -6,15 +6,15 @@ from torchvision import datapoints from torchvision.prototype.datapoints import Label, OneHotLabel from torchvision.transforms.v2 import functional as F, Transform -from torchvision.transforms.v2._utils import _setup_fill_arg, _setup_size -from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_bounding_box, query_spatial_size +from torchvision.transforms.v2._utils import _get_fill, _setup_fill_arg, _setup_size +from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_bounding_boxes, query_size class FixedSizeCrop(Transform): def __init__( self, size: Union[int, Sequence[int]], - fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0, + fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0, padding_mode: str = "constant", ) -> None: super().__init__() @@ -39,14 +39,14 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None: f"{type(self).__name__}() requires input sample to contain an tensor or PIL image or a Video." ) - if has_any(flat_inputs, datapoints.BoundingBox) and not has_any(flat_inputs, Label, OneHotLabel): + if has_any(flat_inputs, datapoints.BoundingBoxes) and not has_any(flat_inputs, Label, OneHotLabel): raise TypeError( - f"If a BoundingBox is contained in the input sample, " + f"If a BoundingBoxes is contained in the input sample, " f"{type(self).__name__}() also requires it to contain a Label or OneHotLabel." 
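The hunks in this area move the prototype transforms onto the pluralized box helpers (`convert_format_bounding_boxes`, `clamp_bounding_boxes`, `query_bounding_boxes`) and the `canvas_size` keyword. A minimal sketch of the clamp-then-drop-degenerate-boxes pattern used in `SimpleCopyPaste` and `FixedSizeCrop` above (the box values are arbitrary, and the tensors are plain tensors as in those call sites):

```python
import torch
import torchvision

torchvision.disable_beta_transforms_warning()
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

canvas_size = (100, 100)
# Plain XYXY tensor: the second box is degenerate (zero width) and the third
# sticks out of the canvas.
boxes = torch.tensor(
    [[10.0, 10.0, 40.0, 40.0], [5.0, 5.0, 5.0, 30.0], [80.0, 80.0, 120.0, 120.0]]
)

# Clamp to the canvas; for plain tensors both format and canvas_size are required.
boxes = F.clamp_bounding_boxes(boxes, format=datapoints.BoundingBoxFormat.XYXY, canvas_size=canvas_size)

# Convert to XYWH to read off widths/heights and keep only non-degenerate boxes,
# mirroring the validity check in the hunks above.
wh = F.convert_format_bounding_boxes(
    boxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=datapoints.BoundingBoxFormat.XYWH
)[..., 2:]
boxes = boxes[torch.all(wh > 0, dim=-1)]
```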
) def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: - height, width = query_spatial_size(flat_inputs) + height, width = query_size(flat_inputs) new_height = min(height, self.crop_height) new_width = min(width, self.crop_width) @@ -61,13 +61,13 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: bounding_boxes: Optional[torch.Tensor] try: - bounding_boxes = query_bounding_box(flat_inputs) + bounding_boxes = query_bounding_boxes(flat_inputs) except ValueError: bounding_boxes = None if needs_crop and bounding_boxes is not None: format = bounding_boxes.format - bounding_boxes, spatial_size = F.crop_bounding_box( + bounding_boxes, canvas_size = F.crop_bounding_boxes( bounding_boxes.as_subclass(torch.Tensor), format=format, top=top, @@ -75,8 +75,8 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: height=new_height, width=new_width, ) - bounding_boxes = F.clamp_bounding_box(bounding_boxes, format=format, spatial_size=spatial_size) - height_and_width = F.convert_format_bounding_box( + bounding_boxes = F.clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size) + height_and_width = F.convert_format_bounding_boxes( bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYWH )[..., 2:] is_valid = torch.all(height_and_width > 0, dim=-1) @@ -112,14 +112,14 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if params["is_valid"] is not None: if isinstance(inpt, (Label, OneHotLabel, datapoints.Mask)): inpt = inpt.wrap_like(inpt, inpt[params["is_valid"]]) # type: ignore[arg-type] - elif isinstance(inpt, datapoints.BoundingBox): - inpt = datapoints.BoundingBox.wrap_like( + elif isinstance(inpt, datapoints.BoundingBoxes): + inpt = datapoints.BoundingBoxes.wrap_like( inpt, - F.clamp_bounding_box(inpt[params["is_valid"]], format=inpt.format, spatial_size=inpt.spatial_size), + F.clamp_bounding_boxes(inpt[params["is_valid"]], format=inpt.format, canvas_size=inpt.canvas_size), ) if params["needs_pad"]: - fill = self._fill[type(inpt)] + fill = _get_fill(self._fill, type(inpt)) inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode) return inpt diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index 3a4e6e956f3..51a2ea9074a 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -1,15 +1,29 @@ +import functools import warnings -from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union +from collections import defaultdict +from typing import Any, Dict, Optional, Sequence, Tuple, Type, TypeVar, Union import torch from torchvision import datapoints from torchvision.transforms.v2 import Transform -from torchvision.transforms.v2._utils import _get_defaultdict from torchvision.transforms.v2.utils import is_simple_tensor +T = TypeVar("T") + + +def _default_arg(value: T) -> T: + return value + + +def _get_defaultdict(default: T) -> Dict[Any, T]: + # This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle. 
+ # If it were possible, we could replace this with `defaultdict(lambda: default)` + return defaultdict(functools.partial(_default_arg, default)) + + class PermuteDimensions(Transform): _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video) diff --git a/torchvision/transforms/v2/__init__.py b/torchvision/transforms/v2/__init__.py index 6573446a33a..8ce9bee9b4d 100644 --- a/torchvision/transforms/v2/__init__.py +++ b/torchvision/transforms/v2/__init__.py @@ -4,7 +4,7 @@ from ._transform import Transform # usort: skip -from ._augment import RandomErasing +from ._augment import CutMix, MixUp, RandomErasing from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide from ._color import ( ColorJitter, @@ -39,8 +39,17 @@ ScaleJitter, TenCrop, ) -from ._meta import ClampBoundingBox, ConvertBoundingBoxFormat, ConvertDtype, ConvertImageDtype -from ._misc import GaussianBlur, Identity, Lambda, LinearTransformation, Normalize, SanitizeBoundingBox, ToDtype +from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat +from ._misc import ( + ConvertImageDtype, + GaussianBlur, + Identity, + Lambda, + LinearTransformation, + Normalize, + SanitizeBoundingBoxes, + ToDtype, +) from ._temporal import UniformTemporalSubsample from ._type_conversion import PILToTensor, ToImagePIL, ToImageTensor, ToPILImage diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index 937e3508a87..780ffccf6b2 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -5,11 +5,14 @@ import PIL.Image import torch +from torch.nn.functional import one_hot +from torch.utils._pytree import tree_flatten, tree_unflatten from torchvision import datapoints, transforms as _transforms from torchvision.transforms.v2 import functional as F -from ._transform import _RandomApplyTransform -from .utils import is_simple_tensor, query_chw +from ._transform import _RandomApplyTransform, Transform +from ._utils import _parse_labels_getter +from .utils import has_any, is_simple_tensor, query_chw, query_size class RandomErasing(_RandomApplyTransform): @@ -135,3 +138,185 @@ def _transform( inpt = F.erase(inpt, **params, inplace=self.inplace) return inpt + + +class _BaseMixUpCutMix(Transform): + def __init__(self, *, alpha: float = 1.0, num_classes: int, labels_getter="default") -> None: + super().__init__() + self.alpha = float(alpha) + self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha])) + + self.num_classes = num_classes + + self._labels_getter = _parse_labels_getter(labels_getter) + + def forward(self, *inputs): + inputs = inputs if len(inputs) > 1 else inputs[0] + flat_inputs, spec = tree_flatten(inputs) + needs_transform_list = self._needs_transform_list(flat_inputs) + + if has_any(flat_inputs, PIL.Image.Image, datapoints.BoundingBoxes, datapoints.Mask): + raise ValueError(f"{type(self).__name__}() does not support PIL images, bounding boxes and masks.") + + labels = self._labels_getter(inputs) + if not isinstance(labels, torch.Tensor): + raise ValueError(f"The labels must be a tensor, but got {type(labels)} instead.") + elif labels.ndim != 1: + raise ValueError( + f"labels tensor should be of shape (batch_size,) " f"but got shape {labels.shape} instead." 
+            )
+
+        params = {
+            "labels": labels,
+            "batch_size": labels.shape[0],
+            **self._get_params(
+                [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform]
+            ),
+        }
+
+        # By default, the labels will be False inside needs_transform_list, since they are a torch.Tensor coming
+        # after an image or video. However, we need to handle them in _transform, so we make sure to set them to True
+        needs_transform_list[next(idx for idx, inpt in enumerate(flat_inputs) if inpt is labels)] = True
+        flat_outputs = [
+            self._transform(inpt, params) if needs_transform else inpt
+            for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
+        ]
+
+        return tree_unflatten(flat_outputs, spec)
+
+    def _check_image_or_video(self, inpt: torch.Tensor, *, batch_size: int):
+        expected_num_dims = 5 if isinstance(inpt, datapoints.Video) else 4
+        if inpt.ndim != expected_num_dims:
+            raise ValueError(
+                f"Expected a batched input with {expected_num_dims} dims, but got {inpt.ndim} dimensions instead."
+            )
+        if inpt.shape[0] != batch_size:
+            raise ValueError(
+                f"The batch size of the image or video does not match the batch size of the labels: "
+                f"{inpt.shape[0]} != {batch_size}."
+            )
+
+    def _mixup_label(self, label: torch.Tensor, *, lam: float) -> torch.Tensor:
+        label = one_hot(label, num_classes=self.num_classes)
+        if not label.dtype.is_floating_point:
+            label = label.float()
+        return label.roll(1, 0).mul_(1.0 - lam).add_(label.mul(lam))
+
+
+class MixUp(_BaseMixUpCutMix):
+    """[BETA] Apply MixUp to the provided batch of images and labels.
+
+    .. v2betastatus:: MixUp transform
+
+    Paper: `mixup: Beyond Empirical Risk Minimization `_.
+
+    .. note::
+        This transform is meant to be used on **batches** of samples, not
+        individual images. See
+        :ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage
+        examples.
+        The sample pairing is deterministic and done by matching consecutive
+        samples in the batch, so the batch needs to be shuffled (this is an
+        implementation detail, not a guaranteed convention.)
+
+    In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed
+    into a tensor of shape ``(batch_size, num_classes)``.
+
+    Args:
+        alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
+        num_classes (int): number of classes in the batch. Used for one-hot-encoding.
+        labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
+            By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
+            common scenario where this transform is called as ``MixUp()(imgs_batch, labels_batch)``.
+            It can also be a callable that takes the same input as the transform, and returns the labels.
+    """
+
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        return dict(lam=float(self._dist.sample(())))  # type: ignore[arg-type]
+
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        lam = params["lam"]
+
+        if inpt is params["labels"]:
+            return self._mixup_label(inpt, lam=lam)
+        elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt):
+            self._check_image_or_video(inpt, batch_size=params["batch_size"])
+
+            output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
+
+            if isinstance(inpt, (datapoints.Image, datapoints.Video)):
+                output = type(inpt).wrap_like(inpt, output)  # type: ignore[arg-type]
+
+            return output
+        else:
+            return inpt
+
+
+class CutMix(_BaseMixUpCutMix):
+    """[BETA] Apply CutMix to the provided batch of images and labels.
+
+    .. v2betastatus:: CutMix transform
+
+    Paper: `CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features
+    `_.
+
+    .. note::
+        This transform is meant to be used on **batches** of samples, not
+        individual images. See
+        :ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage
+        examples.
+        The sample pairing is deterministic and done by matching consecutive
+        samples in the batch, so the batch needs to be shuffled (this is an
+        implementation detail, not a guaranteed convention.)
+
+    In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed
+    into a tensor of shape ``(batch_size, num_classes)``.
+
+    Args:
+        alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
+        num_classes (int): number of classes in the batch. Used for one-hot-encoding.
+        labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
+            By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
+            common scenario where this transform is called as ``CutMix()(imgs_batch, labels_batch)``.
+            It can also be a callable that takes the same input as the transform, and returns the labels.
+ """ + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + lam = float(self._dist.sample(())) # type: ignore[arg-type] + + H, W = query_size(flat_inputs) + + r_x = torch.randint(W, size=(1,)) + r_y = torch.randint(H, size=(1,)) + + r = 0.5 * math.sqrt(1.0 - lam) + r_w_half = int(r * W) + r_h_half = int(r * H) + + x1 = int(torch.clamp(r_x - r_w_half, min=0)) + y1 = int(torch.clamp(r_y - r_h_half, min=0)) + x2 = int(torch.clamp(r_x + r_w_half, max=W)) + y2 = int(torch.clamp(r_y + r_h_half, max=H)) + box = (x1, y1, x2, y2) + + lam_adjusted = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H)) + + return dict(box=box, lam_adjusted=lam_adjusted) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if inpt is params["labels"]: + return self._mixup_label(inpt, lam=params["lam_adjusted"]) + elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt): + self._check_image_or_video(inpt, batch_size=params["batch_size"]) + + x1, y1, x2, y2 = params["box"] + rolled = inpt.roll(1, 0) + output = inpt.clone() + output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2] + + if isinstance(inpt, (datapoints.Image, datapoints.Video)): + output = inpt.wrap_like(inpt, output) # type: ignore[arg-type] + + return output + else: + return inpt diff --git a/torchvision/transforms/v2/_auto_augment.py b/torchvision/transforms/v2/_auto_augment.py index 34c0ced43d2..146c8c236ef 100644 --- a/torchvision/transforms/v2/_auto_augment.py +++ b/torchvision/transforms/v2/_auto_augment.py @@ -9,9 +9,9 @@ from torchvision.transforms import _functional_tensor as _FT from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, InterpolationMode, Transform from torchvision.transforms.v2.functional._geometry import _check_interpolation -from torchvision.transforms.v2.functional._meta import get_spatial_size +from torchvision.transforms.v2.functional._meta import get_size -from ._utils import _setup_fill_arg +from ._utils import _get_fill, _setup_fill_arg from .utils import check_type, is_simple_tensor @@ -20,7 +20,7 @@ def __init__( self, *, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, - fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None, + fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None, ) -> None: super().__init__() self.interpolation = _check_interpolation(interpolation) @@ -34,7 +34,7 @@ def _get_random_item(self, dct: Dict[str, Tuple[Callable, bool]]) -> Tuple[str, def _flatten_and_extract_image_or_video( self, inputs: Any, - unsupported_types: Tuple[Type, ...] = (datapoints.BoundingBox, datapoints.Mask), + unsupported_types: Tuple[Type, ...] 
= (datapoints.BoundingBoxes, datapoints.Mask), ) -> Tuple[Tuple[List[Any], TreeSpec, int], Union[datapoints._ImageType, datapoints._VideoType]]: flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0]) needs_transform_list = self._needs_transform_list(flat_inputs) @@ -80,9 +80,9 @@ def _apply_image_or_video_transform( transform_id: str, magnitude: float, interpolation: Union[InterpolationMode, int], - fill: Dict[Type, datapoints._FillTypeJIT], + fill: Dict[Union[Type, str], datapoints._FillTypeJIT], ) -> Union[datapoints._ImageType, datapoints._VideoType]: - fill_ = fill[type(image)] + fill_ = _get_fill(fill, type(image)) if transform_id == "Identity": return image @@ -214,7 +214,7 @@ def __init__( self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, - fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None, + fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None, ) -> None: super().__init__(interpolation=interpolation, fill=fill) self.policy = policy @@ -312,7 +312,7 @@ def _get_policies( def forward(self, *inputs: Any) -> Any: flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs) - height, width = get_spatial_size(image_or_video) + height, width = get_size(image_or_video) policy = self._policies[int(torch.randint(len(self._policies), ()))] @@ -394,7 +394,7 @@ def __init__( magnitude: int = 9, num_magnitude_bins: int = 31, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, - fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None, + fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None, ) -> None: super().__init__(interpolation=interpolation, fill=fill) self.num_ops = num_ops @@ -403,7 +403,7 @@ def __init__( def forward(self, *inputs: Any) -> Any: flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs) - height, width = get_spatial_size(image_or_video) + height, width = get_size(image_or_video) for _ in range(self.num_ops): transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE) @@ -467,14 +467,14 @@ def __init__( self, num_magnitude_bins: int = 31, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, - fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None, + fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None, ): super().__init__(interpolation=interpolation, fill=fill) self.num_magnitude_bins = num_magnitude_bins def forward(self, *inputs: Any) -> Any: flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs) - height, width = get_spatial_size(image_or_video) + height, width = get_size(image_or_video) transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE) @@ -550,7 +550,7 @@ def __init__( alpha: float = 1.0, all_ops: bool = True, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, - fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None, + fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None, ) -> None: super().__init__(interpolation=interpolation, fill=fill) self._PARAMETER_MAX = 10 @@ -568,7 +568,7 @@ def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor: def forward(self, *inputs: Any) -> Any: flat_inputs_with_spec, orig_image_or_video 
= self._flatten_and_extract_image_or_video(inputs) - height, width = get_spatial_size(orig_image_or_video) + height, width = get_size(orig_image_or_video) if isinstance(orig_image_or_video, torch.Tensor): image_or_video = orig_image_or_video diff --git a/torchvision/transforms/v2/_container.py b/torchvision/transforms/v2/_container.py index fffef4157bd..8f591c49707 100644 --- a/torchvision/transforms/v2/_container.py +++ b/torchvision/transforms/v2/_container.py @@ -43,13 +43,16 @@ def __init__(self, transforms: Sequence[Callable]) -> None: super().__init__() if not isinstance(transforms, Sequence): raise TypeError("Argument transforms should be a sequence of callables") + elif not transforms: + raise ValueError("Pass at least one transform") self.transforms = transforms def forward(self, *inputs: Any) -> Any: - sample = inputs if len(inputs) > 1 else inputs[0] + needs_unpacking = len(inputs) > 1 for transform in self.transforms: - sample = transform(sample) - return sample + outputs = transform(*inputs) + inputs = outputs if needs_unpacking else (outputs,) + return outputs def extra_repr(self) -> str: format_string = [] diff --git a/torchvision/transforms/v2/_deprecated.py b/torchvision/transforms/v2/_deprecated.py index e900e853d2b..1cb135a3062 100644 --- a/torchvision/transforms/v2/_deprecated.py +++ b/torchvision/transforms/v2/_deprecated.py @@ -16,7 +16,7 @@ class ToTensor(Transform): .. warning:: :class:`v2.ToTensor` is deprecated and will be removed in a future release. - Please use instead ``transforms.Compose([transforms.ToImageTensor(), transforms.ConvertImageDtype()])``. + Please use instead ``v2.Compose([transforms.ToImageTensor(), v2.ToDtype(torch.float32, scale=True)])``. This transform does not support torchscript. @@ -40,7 +40,7 @@ class ToTensor(Transform): def __init__(self) -> None: warnings.warn( "The transform `ToTensor()` is deprecated and will be removed in a future release. " - "Instead, please use `transforms.Compose([transforms.ToImageTensor(), transforms.ConvertImageDtype()])`." + "Instead, please use `v2.Compose([transforms.ToImageTensor(), v2.ToDtype(torch.float32, scale=True)])`." ) super().__init__() diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index 731d768c2a6..c7a1e39286f 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -17,12 +17,13 @@ _check_padding_arg, _check_padding_mode_arg, _check_sequence_input, + _get_fill, _setup_angle, _setup_fill_arg, _setup_float_or_seq, _setup_size, ) -from .utils import has_all, has_any, is_simple_tensor, query_bounding_box, query_spatial_size +from .utils import has_all, has_any, is_simple_tensor, query_bounding_boxes, query_size class RandomHorizontalFlip(_RandomApplyTransform): @@ -31,7 +32,7 @@ class RandomHorizontalFlip(_RandomApplyTransform): .. v2betastatus:: RandomHorizontalFlip transform If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`, - :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.) + :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.) it can have arbitrary number of leading batch dimensions. For example, the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. @@ -51,7 +52,7 @@ class RandomVerticalFlip(_RandomApplyTransform): .. v2betastatus:: RandomVerticalFlip transform If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. 
:class:`~torchvision.datapoints.Image`, - :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.) + :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.) it can have arbitrary number of leading batch dimensions. For example, the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. @@ -71,7 +72,7 @@ class Resize(Transform): .. v2betastatus:: Resize transform If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`, - :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.) + :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.) it can have arbitrary number of leading batch dimensions. For example, the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. @@ -165,7 +166,7 @@ class CenterCrop(Transform): .. v2betastatus:: CenterCrop transform If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`, - :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.) + :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.) it can have arbitrary number of leading batch dimensions. For example, the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. @@ -193,7 +194,7 @@ class RandomResizedCrop(Transform): .. v2betastatus:: RandomResizedCrop transform If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`, - :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.) + :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.) it can have arbitrary number of leading batch dimensions. For example, the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. @@ -267,7 +268,7 @@ def __init__( self._log_ratio = torch.log(torch.tensor(self.ratio)) def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: - height, width = query_spatial_size(flat_inputs) + height, width = query_size(flat_inputs) area = height * width log_ratio = self._log_ratio @@ -371,8 +372,8 @@ def _transform( return F.five_crop(inpt, self.size) def _check_inputs(self, flat_inputs: List[Any]) -> None: - if has_any(flat_inputs, datapoints.BoundingBox, datapoints.Mask): - raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()") + if has_any(flat_inputs, datapoints.BoundingBoxes, datapoints.Mask): + raise TypeError(f"BoundingBoxes'es and Mask's are not supported by {type(self).__name__}()") class TenCrop(Transform): @@ -414,8 +415,8 @@ def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) self.vertical_flip = vertical_flip def _check_inputs(self, flat_inputs: List[Any]) -> None: - if has_any(flat_inputs, datapoints.BoundingBox, datapoints.Mask): - raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()") + if has_any(flat_inputs, datapoints.BoundingBoxes, datapoints.Mask): + raise TypeError(f"BoundingBoxes'es and Mask's are not supported by {type(self).__name__}()") def _transform( self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any] @@ -440,7 +441,7 @@ class Pad(Transform): .. v2betastatus:: Pad transform If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. 
:class:`~torchvision.datapoints.Image`, - :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.) + :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.) it can have arbitrary number of leading batch dimensions. For example, the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. @@ -487,7 +488,7 @@ def _extract_params_for_v1_transform(self) -> Dict[str, Any]: def __init__( self, padding: Union[int, Sequence[int]], - fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0, + fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0, padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant", ) -> None: super().__init__() @@ -504,7 +505,7 @@ def __init__( self.padding_mode = padding_mode def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - fill = self._fill[type(inpt)] + fill = _get_fill(self._fill, type(inpt)) return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type] @@ -525,7 +526,7 @@ class RandomZoomOut(_RandomApplyTransform): output_height = input_height * r If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`, - :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.) + :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.) it can have arbitrary number of leading batch dimensions. For example, the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. @@ -537,12 +538,12 @@ class RandomZoomOut(_RandomApplyTransform): ``Mask`` will be filled with 0. side_range (sequence of floats, optional): tuple of two floats defines minimum and maximum factors to scale the input size. - p (float, optional): probability of the input being flipped. Default value is 0.5 + p (float, optional): probability that the zoom operation will be performed. """ def __init__( self, - fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0, + fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0, side_range: Sequence[float] = (1.0, 4.0), p: float = 0.5, ) -> None: @@ -558,7 +559,7 @@ def __init__( raise ValueError(f"Invalid canvas side range provided {side_range}.") def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: - orig_h, orig_w = query_spatial_size(flat_inputs) + orig_h, orig_w = query_size(flat_inputs) r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0]) canvas_width = int(orig_w * r) @@ -574,7 +575,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(padding=padding) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - fill = self._fill[type(inpt)] + fill = _get_fill(self._fill, type(inpt)) return F.pad(inpt, **params, fill=fill) @@ -584,7 +585,7 @@ class RandomRotation(Transform): .. v2betastatus:: RandomRotation transform If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`, - :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.) + :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.) it can have arbitrary number of leading batch dimensions. For example, the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. 
@@ -620,7 +621,7 @@ def __init__( interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, expand: bool = False, center: Optional[List[float]] = None, - fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0, + fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0, ) -> None: super().__init__() self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) @@ -640,7 +641,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(angle=angle) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - fill = self._fill[type(inpt)] + fill = _get_fill(self._fill, type(inpt)) return F.rotate( inpt, **params, @@ -657,7 +658,7 @@ class RandomAffine(Transform): .. v2betastatus:: RandomAffine transform If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`, - :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.) + :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.) it can have arbitrary number of leading batch dimensions. For example, the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. @@ -702,7 +703,7 @@ def __init__( scale: Optional[Sequence[float]] = None, shear: Optional[Union[int, float, Sequence[float]]] = None, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, - fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0, + fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0, center: Optional[List[float]] = None, ) -> None: super().__init__() @@ -735,7 +736,7 @@ def __init__( self.center = center def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: - height, width = query_spatial_size(flat_inputs) + height, width = query_size(flat_inputs) angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item() if self.translate is not None: @@ -762,7 +763,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(angle=angle, translate=translate, scale=scale, shear=shear) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - fill = self._fill[type(inpt)] + fill = _get_fill(self._fill, type(inpt)) return F.affine( inpt, **params, @@ -778,7 +779,7 @@ class RandomCrop(Transform): .. v2betastatus:: RandomCrop transform If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`, - :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.) + :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.) it can have arbitrary number of leading batch dimensions. For example, the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. 
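The fill-handling hunks above and below replace the old defaultdict lookup with `_get_fill`, which accepts per-type keys plus an `"others"` catch-all string key. A minimal sketch of what that enables from user code, assuming the beta v2 API on this branch (the concrete values and shapes are illustrative, not taken from this patch):

import torch
from torchvision import datapoints
from torchvision.transforms import v2

# Per-type fill: gray for images, 0 for masks, and an "others" catch-all that
# _get_fill() falls back to for any datapoint type not listed explicitly.
transform = v2.RandomRotation(
    degrees=30,
    fill={datapoints.Image: 127, datapoints.Mask: 0, "others": 0},
)

img = datapoints.Image(torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8))
mask = datapoints.Mask(torch.zeros(1, 224, 224, dtype=torch.uint8))
out_img, out_mask = transform(img, mask)  # both rotated by the same sampled angle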
@@ -840,7 +841,7 @@ def __init__( size: Union[int, Sequence[int]], padding: Optional[Union[int, Sequence[int]]] = None, pad_if_needed: bool = False, - fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0, + fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0, padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant", ) -> None: super().__init__() @@ -859,7 +860,7 @@ def __init__( self.padding_mode = padding_mode def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: - padded_height, padded_width = query_spatial_size(flat_inputs) + padded_height, padded_width = query_size(flat_inputs) if self.padding is not None: pad_left, pad_right, pad_top, pad_bottom = self.padding @@ -918,7 +919,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if params["needs_pad"]: - fill = self._fill[type(inpt)] + fill = _get_fill(self._fill, type(inpt)) inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode) if params["needs_crop"]: @@ -933,7 +934,7 @@ class RandomPerspective(_RandomApplyTransform): .. v2betastatus:: RandomPerspective transform If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`, - :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.) + :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.) it can have arbitrary number of leading batch dimensions. For example, the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. @@ -959,7 +960,7 @@ def __init__( distortion_scale: float = 0.5, p: float = 0.5, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, - fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0, + fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0, ) -> None: super().__init__(p=p) @@ -972,7 +973,7 @@ def __init__( self._fill = _setup_fill_arg(fill) def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: - height, width = query_spatial_size(flat_inputs) + height, width = query_size(flat_inputs) distortion_scale = self.distortion_scale @@ -1002,7 +1003,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(coefficients=perspective_coeffs) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - fill = self._fill[type(inpt)] + fill = _get_fill(self._fill, type(inpt)) return F.perspective( inpt, None, @@ -1019,7 +1020,7 @@ class ElasticTransform(Transform): .. v2betastatus:: RandomPerspective transform If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`, - :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.) + :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.) it can have arbitrary number of leading batch dimensions. For example, the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. 
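Most of the remaining changes in these geometry hunks are the metadata rename: `query_spatial_size`/`get_spatial_size` become `query_size`/`get_size`, and `spatial_size` becomes `canvas_size`. A small sketch of the renamed surface, assuming the `BoundingBoxes` constructor mirrors the renamed attribute:

import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

boxes = datapoints.BoundingBoxes(
    torch.tensor([[10.0, 20.0, 50.0, 80.0]]),
    format=datapoints.BoundingBoxFormat.XYXY,
    canvas_size=(224, 224),  # previously `spatial_size`
)

print(F.get_size(boxes))   # [224, 224]; previously `get_spatial_size`
print(boxes.canvas_size)   # the same metadata, stored on the datapoint itself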
@@ -1061,7 +1062,7 @@ def __init__( alpha: Union[float, Sequence[float]] = 50.0, sigma: Union[float, Sequence[float]] = 5.0, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, - fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0, + fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0, ) -> None: super().__init__() self.alpha = _setup_float_or_seq(alpha, "alpha", 2) @@ -1072,7 +1073,7 @@ def __init__( self._fill = _setup_fill_arg(fill) def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: - size = list(query_spatial_size(flat_inputs)) + size = list(query_size(flat_inputs)) dx = torch.rand([1, 1] + size) * 2 - 1 if self.sigma[0] > 0.0: @@ -1095,7 +1096,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(displacement=displacement) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - fill = self._fill[type(inpt)] + fill = _get_fill(self._fill, type(inpt)) return F.elastic( inpt, **params, @@ -1110,15 +1111,15 @@ class RandomIoUCrop(Transform): .. v2betastatus:: RandomIoUCrop transform - This transformation requires an image or video data and ``datapoints.BoundingBox`` in the input. + This transformation requires an image or video data and ``datapoints.BoundingBoxes`` in the input. .. warning:: In order to properly remove the bounding boxes below the IoU threshold, `RandomIoUCrop` - must be followed by :class:`~torchvision.transforms.v2.SanitizeBoundingBox`, either immediately + must be followed by :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes`, either immediately after or later in the transforms pipeline. If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`, - :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.) + :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.) it can have arbitrary number of leading batch dimensions. For example, the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. 
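The warning in the RandomIoUCrop docstring above is the practical contract for this transform: it only marks boxes below the IoU threshold as degenerate, and `SanitizeBoundingBoxes` later removes them together with their labels and masks. A pipeline sketch under that assumption (the surrounding transforms are illustrative):

from torchvision.transforms import v2

# RandomIoUCrop zeroes out boxes whose centers fall outside the sampled crop;
# SanitizeBoundingBoxes then drops those degenerate boxes and their labels/masks.
transform = v2.Compose([
    v2.RandomIoUCrop(),
    v2.RandomHorizontalFlip(p=0.5),
    v2.SanitizeBoundingBoxes(),  # must come after RandomIoUCrop
])

With the `Compose.forward` change earlier in this diff, such a pipeline can be called as `transform(img, target)` directly, without packing the sample into a single tuple first.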
@@ -1155,7 +1156,7 @@ def __init__( def _check_inputs(self, flat_inputs: List[Any]) -> None: if not ( - has_all(flat_inputs, datapoints.BoundingBox) + has_all(flat_inputs, datapoints.BoundingBoxes) and has_any(flat_inputs, PIL.Image.Image, datapoints.Image, is_simple_tensor) ): raise TypeError( @@ -1164,8 +1165,8 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None: ) def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: - orig_h, orig_w = query_spatial_size(flat_inputs) - bboxes = query_bounding_box(flat_inputs) + orig_h, orig_w = query_size(flat_inputs) + bboxes = query_bounding_boxes(flat_inputs) while True: # sample an option @@ -1193,7 +1194,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: continue # check for any valid boxes with centers within the crop area - xyxy_bboxes = F.convert_format_bounding_box( + xyxy_bboxes = F.convert_format_bounding_boxes( bboxes.as_subclass(torch.Tensor), bboxes.format, datapoints.BoundingBoxFormat.XYXY ) cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2]) @@ -1220,9 +1221,9 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: output = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]) - if isinstance(output, datapoints.BoundingBox): + if isinstance(output, datapoints.BoundingBoxes): # We "mark" the invalid boxes as degenreate, and they can be - # removed by a later call to SanitizeBoundingBox() + # removed by a later call to SanitizeBoundingBoxes() output[~params["is_within_crop_area"]] = 0 return output @@ -1235,7 +1236,7 @@ class ScaleJitter(Transform): .. v2betastatus:: ScaleJitter transform If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`, - :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.) + :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.) it can have arbitrary number of leading batch dimensions. For example, the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. @@ -1282,7 +1283,7 @@ def __init__( self.antialias = antialias def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: - orig_height, orig_width = query_spatial_size(flat_inputs) + orig_height, orig_width = query_size(flat_inputs) scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0]) r = min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) * scale @@ -1301,7 +1302,7 @@ class RandomShortestSize(Transform): .. v2betastatus:: RandomShortestSize transform If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`, - :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.) + :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.) it can have arbitrary number of leading batch dimensions. For example, the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. 
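`convert_format_bounding_boxes` (formerly `convert_format_bounding_box`) is what `_get_params` above uses to obtain XYXY coordinates before computing box centers. The same conversion on a plain tensor, as a sketch that assumes the keyword form shown in these hunks:

import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

cxcywh = torch.tensor([[50.0, 50.0, 20.0, 40.0]])  # cx, cy, w, h
xyxy = F.convert_format_bounding_boxes(
    cxcywh,
    old_format=datapoints.BoundingBoxFormat.CXCYWH,
    new_format=datapoints.BoundingBoxFormat.XYXY,
)

# Box centers, recovered the same way RandomIoUCrop._get_params does it:
cx = 0.5 * (xyxy[..., 0] + xyxy[..., 2])
cy = 0.5 * (xyxy[..., 1] + xyxy[..., 3])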
@@ -1347,7 +1348,7 @@ def __init__( self.antialias = antialias def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: - orig_height, orig_width = query_spatial_size(flat_inputs) + orig_height, orig_width = query_size(flat_inputs) min_size = self.min_size[int(torch.randint(len(self.min_size), ()))] r = min_size / min(orig_height, orig_width) @@ -1380,7 +1381,7 @@ class RandomResize(Transform): output_height = size If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`, - :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.) + :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.) it can have arbitrary number of leading batch dimensions. For example, the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. diff --git a/torchvision/transforms/v2/_meta.py b/torchvision/transforms/v2/_meta.py index b7e2a42259f..f0b62221083 100644 --- a/torchvision/transforms/v2/_meta.py +++ b/torchvision/transforms/v2/_meta.py @@ -1,12 +1,8 @@ from typing import Any, Dict, Union -import torch - -from torchvision import datapoints, transforms as _transforms +from torchvision import datapoints from torchvision.transforms.v2 import functional as F, Transform -from .utils import is_simple_tensor - class ConvertBoundingBoxFormat(Transform): """[BETA] Convert bounding box coordinates to the given ``format``, eg from "CXCYWH" to "XYXY". @@ -19,7 +15,7 @@ class ConvertBoundingBoxFormat(Transform): string values match the enums, e.g. "XYXY" or "XYWH" etc. """ - _transformed_types = (datapoints.BoundingBox,) + _transformed_types = (datapoints.BoundingBoxes,) def __init__(self, format: Union[str, datapoints.BoundingBoxFormat]) -> None: super().__init__() @@ -27,61 +23,20 @@ def __init__(self, format: Union[str, datapoints.BoundingBoxFormat]) -> None: format = datapoints.BoundingBoxFormat[format] self.format = format - def _transform(self, inpt: datapoints.BoundingBox, params: Dict[str, Any]) -> datapoints.BoundingBox: - return F.convert_format_bounding_box(inpt, new_format=self.format) # type: ignore[return-value] - - -class ConvertDtype(Transform): - """[BETA] Convert input image or video to the given ``dtype`` and scale the values accordingly. - - .. v2betastatus:: ConvertDtype transform - - This function does not support PIL Image. - - Args: - dtype (torch.dtype): Desired data type of the output - - .. note:: - - When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly. - If converted back and forth, this mismatch has no effect. - - Raises: - RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as - well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to - overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range - of the integer ``dtype``. 
- """ - - _v1_transform_cls = _transforms.ConvertImageDtype - - _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video) - - def __init__(self, dtype: torch.dtype = torch.float32) -> None: - super().__init__() - self.dtype = dtype - - def _transform( - self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any] - ) -> Union[datapoints._TensorImageType, datapoints._TensorVideoType]: - return F.convert_dtype(inpt, self.dtype) - - -# We changed the name to align it with the new naming scheme. Still, `ConvertImageDtype` is -# prevalent and well understood. Thus, we just alias it without deprecating the old name. -ConvertImageDtype = ConvertDtype + def _transform(self, inpt: datapoints.BoundingBoxes, params: Dict[str, Any]) -> datapoints.BoundingBoxes: + return F.convert_format_bounding_boxes(inpt, new_format=self.format) # type: ignore[return-value] -class ClampBoundingBox(Transform): +class ClampBoundingBoxes(Transform): """[BETA] Clamp bounding boxes to their corresponding image dimensions. - The clamping is done according to the bounding boxes' ``spatial_size`` meta-data. + The clamping is done according to the bounding boxes' ``canvas_size`` meta-data. - .. v2betastatus:: ClampBoundingBox transform + .. v2betastatus:: ClampBoundingBoxes transform """ - _transformed_types = (datapoints.BoundingBox,) + _transformed_types = (datapoints.BoundingBoxes,) - def _transform(self, inpt: datapoints.BoundingBox, params: Dict[str, Any]) -> datapoints.BoundingBox: - return F.clamp_bounding_box(inpt) # type: ignore[return-value] + def _transform(self, inpt: datapoints.BoundingBoxes, params: Dict[str, Any]) -> datapoints.BoundingBoxes: + return F.clamp_bounding_boxes(inpt) # type: ignore[return-value] diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index 90741c4ec7d..a799070ee1e 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -1,7 +1,5 @@ -import collections import warnings -from contextlib import suppress -from typing import Any, Callable, cast, Dict, List, Mapping, Optional, Sequence, Type, Union +from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Type, Union import PIL.Image @@ -11,8 +9,8 @@ from torchvision import datapoints, transforms as _transforms from torchvision.transforms.v2 import functional as F, Transform -from ._utils import _get_defaultdict, _setup_float_or_seq, _setup_size -from .utils import has_any, is_simple_tensor, query_bounding_box +from ._utils import _parse_labels_getter, _setup_float_or_seq, _setup_size +from .utils import has_any, is_simple_tensor, query_bounding_boxes # TODO: do we want/need to expose this? @@ -225,48 +223,125 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: class ToDtype(Transform): - """[BETA] Converts the input to a specific dtype - this does not scale values. + """[BETA] Converts the input to a specific dtype, optionally scaling the values for images or videos. .. v2betastatus:: ToDtype transform + .. note:: + ``ToDtype(dtype, scale=True)`` is the recommended replacement for ``ConvertImageDtype(dtype)``. + Args: dtype (``torch.dtype`` or dict of ``Datapoint`` -> ``torch.dtype``): The dtype to convert to. + If a ``torch.dtype`` is passed, e.g. ``torch.float32``, only images and videos will be converted + to that dtype: this is for compatibility with :class:`~torchvision.transforms.v2.ConvertImageDtype`. A dict can be passed to specify per-datapoint conversions, e.g. 
- ``dtype={datapoints.Image: torch.float32, datapoints.Video: - torch.float64}``. + ``dtype={datapoints.Image: torch.float32, datapoints.Mask: torch.int64, "others":None}``. The "others" + key can be used as a catch-all for any other datapoint type, and ``None`` means no conversion. + scale (bool, optional): Whether to scale the values for images or videos. Default: ``False``. """ _transformed_types = (torch.Tensor,) - def __init__(self, dtype: Union[torch.dtype, Dict[Type, Optional[torch.dtype]]]) -> None: + def __init__( + self, dtype: Union[torch.dtype, Dict[Union[Type, str], Optional[torch.dtype]]], scale: bool = False + ) -> None: super().__init__() - if not isinstance(dtype, dict): - dtype = _get_defaultdict(dtype) - if torch.Tensor in dtype and any(cls in dtype for cls in [datapoints.Image, datapoints.Video]): + + if not isinstance(dtype, (dict, torch.dtype)): + raise ValueError(f"dtype must be a dict or a torch.dtype, got {type(dtype)} instead") + + if ( + isinstance(dtype, dict) + and torch.Tensor in dtype + and any(cls in dtype for cls in [datapoints.Image, datapoints.Video]) + ): warnings.warn( "Got `dtype` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. " "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) " "in case a `datapoints.Image` or `datapoints.Video` is present in the input." ) self.dtype = dtype + self.scale = scale def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - dtype = self.dtype[type(inpt)] + if isinstance(self.dtype, torch.dtype): + # For consistency / BC with ConvertImageDtype, we only care about images or videos when dtype + # is a simple torch.dtype + if not is_simple_tensor(inpt) and not isinstance(inpt, (datapoints.Image, datapoints.Video)): + return inpt + + dtype: Optional[torch.dtype] = self.dtype + elif type(inpt) in self.dtype: + dtype = self.dtype[type(inpt)] + elif "others" in self.dtype: + dtype = self.dtype["others"] + else: + raise ValueError( + f"No dtype was specified for type {type(inpt)}. " + "If you only need to convert the dtype of images or videos, you can just pass e.g. dtype=torch.float32. " + "If you're passing a dict as dtype, " + 'you can use "others" as a catch-all key ' + 'e.g. dtype={datapoints.Mask: torch.int64, "others": None} to pass-through the rest of the inputs.' + ) + + supports_scaling = is_simple_tensor(inpt) or isinstance(inpt, (datapoints.Image, datapoints.Video)) if dtype is None: + if self.scale and supports_scaling: + warnings.warn( + "scale was set to True but no dtype was specified for images or videos: no scaling will be done." + ) return inpt - return inpt.to(dtype=dtype) + return F.to_dtype(inpt, dtype=dtype, scale=self.scale) + + +class ConvertImageDtype(Transform): + """[BETA] Convert input image to the given ``dtype`` and scale the values accordingly. + + .. v2betastatus:: ConvertImageDtype transform + + .. warning:: + Consider using ``ToDtype(dtype, scale=True)`` instead. See :class:`~torchvision.transforms.v2.ToDtype`. + + This function does not support PIL Image. + + Args: + dtype (torch.dtype): Desired data type of the output + + .. note:: + + When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly. + If converted back and forth, this mismatch has no effect. + + Raises: + RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as + well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. 
These conversions might lead to + overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range + of the integer ``dtype``. + """ + + _v1_transform_cls = _transforms.ConvertImageDtype + + _transformed_types = (is_simple_tensor, datapoints.Image) + + def __init__(self, dtype: torch.dtype = torch.float32) -> None: + super().__init__() + self.dtype = dtype -class SanitizeBoundingBox(Transform): + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.to_dtype(inpt, dtype=self.dtype, scale=True) + + +class SanitizeBoundingBoxes(Transform): """[BETA] Remove degenerate/invalid bounding boxes and their corresponding labels and masks. - .. v2betastatus:: SanitizeBoundingBox transform + .. v2betastatus:: SanitizeBoundingBoxes transform This transform removes bounding boxes and their associated labels/masks that: - are below a given ``min_size``: by default this also removes degenerate boxes that have e.g. X2 <= X1. - have any coordinate outside of their corresponding image. You may want to - call :class:`~torchvision.transforms.v2.ClampBoundingBox` first to avoid undesired removals. + call :class:`~torchvision.transforms.v2.ClampBoundingBoxes` first to avoid undesired removals. It is recommended to call it at the end of a pipeline, before passing the input to the models. It is critical to call this transform if @@ -278,12 +353,11 @@ class SanitizeBoundingBox(Transform): Args: min_size (float, optional) The size below which bounding boxes are removed. Default is 1. labels_getter (callable or str or None, optional): indicates how to identify the labels in the input. - It can be a str in which case the input is expected to be a dict, and ``labels_getter`` then specifies - the key whose value corresponds to the labels. It can also be a callable that takes the same input - as the transform, and returns the labels. - By default, this will try to find a "labels" key in the input, if + By default, this will try to find a "labels" key in the input (case-insensitive), if the input is a dict or it is a tuple whose second element is a dict. This heuristic should work well with a lot of datasets, including the built-in torchvision datasets. + It can also be a callable that takes the same input + as the transform, and returns the labels. """ def __init__( @@ -298,72 +372,22 @@ def __init__( self.min_size = min_size self.labels_getter = labels_getter - self._labels_getter: Optional[Callable[[Any], Optional[torch.Tensor]]] - if labels_getter == "default": - self._labels_getter = self._find_labels_default_heuristic - elif callable(labels_getter): - self._labels_getter = labels_getter - elif isinstance(labels_getter, str): - self._labels_getter = lambda inputs: SanitizeBoundingBox._get_dict_or_second_tuple_entry(inputs)[ - labels_getter # type: ignore[index] - ] - elif labels_getter is None: - self._labels_getter = None - else: - raise ValueError( - "labels_getter should either be a str, callable, or 'default'. " - f"Got {labels_getter} of type {type(labels_getter)}." - ) - - @staticmethod - def _get_dict_or_second_tuple_entry(inputs: Any) -> Mapping[str, Any]: - # datasets outputs may be plain dicts like {"img": ..., "labels": ..., "bbox": ...} - # or tuples like (img, {"labels":..., "bbox": ...}) - # This hacky helper accounts for both structures. 
- if isinstance(inputs, tuple): - inputs = inputs[1] - - if not isinstance(inputs, collections.abc.Mapping): - raise ValueError( - f"If labels_getter is a str or 'default', " - f"then the input to forward() must be a dict or a tuple whose second element is a dict." - f" Got {type(inputs)} instead." - ) - return inputs - - @staticmethod - def _find_labels_default_heuristic(inputs: Dict[str, Any]) -> Optional[torch.Tensor]: - # Tries to find a "labels" key, otherwise tries for the first key that contains "label" - case insensitive - # Returns None if nothing is found - inputs = SanitizeBoundingBox._get_dict_or_second_tuple_entry(inputs) - candidate_key = None - with suppress(StopIteration): - candidate_key = next(key for key in inputs.keys() if key.lower() == "labels") - if candidate_key is None: - with suppress(StopIteration): - candidate_key = next(key for key in inputs.keys() if "label" in key.lower()) - if candidate_key is None: - raise ValueError( - "Could not infer where the labels are in the sample. Try passing a callable as the labels_getter parameter?" - "If there are no samples and it is by design, pass labels_getter=None." - ) - return inputs[candidate_key] + self._labels_getter = _parse_labels_getter(labels_getter) def forward(self, *inputs: Any) -> Any: inputs = inputs if len(inputs) > 1 else inputs[0] - if self._labels_getter is None: - labels = None - else: - labels = self._labels_getter(inputs) - if labels is not None and not isinstance(labels, torch.Tensor): - raise ValueError(f"The labels in the input to forward() must be a tensor, got {type(labels)} instead.") + labels = self._labels_getter(inputs) + if labels is not None and not isinstance(labels, torch.Tensor): + raise ValueError( + f"The labels in the input to forward() must be a tensor or None, got {type(labels)} instead." + ) flat_inputs, spec = tree_flatten(inputs) - # TODO: this enforces one single BoundingBox entry. + # TODO: this enforces one single BoundingBoxes entry. # Assuming this transform needs to be called at the end of *any* pipeline that has bboxes... # should we just enforce it for all transforms?? What are the benefits of *not* enforcing this? - boxes = query_bounding_box(flat_inputs) + boxes = query_bounding_boxes(flat_inputs) if boxes.ndim != 2: raise ValueError(f"boxes must be of shape (num_boxes, 4), got {boxes.shape}") @@ -374,8 +398,8 @@ def forward(self, *inputs: Any) -> Any: ) boxes = cast( - datapoints.BoundingBox, - F.convert_format_bounding_box( + datapoints.BoundingBoxes, + F.convert_format_bounding_boxes( boxes, new_format=datapoints.BoundingBoxFormat.XYXY, ), @@ -384,14 +408,14 @@ def forward(self, *inputs: Any) -> Any: valid = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(dim=-1) # TODO: Do we really need to check for out of bounds here? All # transforms should be clamping anyway, so this should never happen? 
-            image_h, image_w = boxes.spatial_size
+            image_h, image_w = boxes.canvas_size
         valid &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w)
         valid &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h)
 
         params = dict(valid=valid, labels=labels)
         flat_outputs = [
             # Even-though it may look like we're transforming all inputs, we don't:
-            # _transform() will only care about BoundingBoxes and the labels
+            # _transform() will only care about BoundingBoxes and the labels
             self._transform(inpt, params)
             for inpt in flat_inputs
         ]
@@ -400,9 +424,9 @@ def forward(self, *inputs: Any) -> Any:
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         is_label = inpt is not None and inpt is params["labels"]
-        is_bounding_box_or_mask = isinstance(inpt, (datapoints.BoundingBox, datapoints.Mask))
+        is_bounding_boxes_or_mask = isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask))
 
-        if not (is_label or is_bounding_box_or_mask):
+        if not (is_label or is_bounding_boxes_or_mask):
             return inpt
 
         output = inpt[params["valid"]]
diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py
index 9942602ebb9..a7826a6645f 100644
--- a/torchvision/transforms/v2/_utils.py
+++ b/torchvision/transforms/v2/_utils.py
@@ -1,11 +1,12 @@
-import functools
+import collections.abc
 import numbers
-from collections import defaultdict
-from typing import Any, Dict, Literal, Sequence, Type, TypeVar, Union
+from contextlib import suppress
+from typing import Any, Callable, Dict, Literal, Optional, Sequence, Type, Union
+
+import torch
 
 from torchvision import datapoints
 from torchvision.datapoints._datapoint import _FillType, _FillTypeJIT
-
 from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size  # noqa: F401
 
 
@@ -26,32 +27,15 @@ def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size:
     return arg
 
 
-def _check_fill_arg(fill: Union[_FillType, Dict[Type, _FillType]]) -> None:
+def _check_fill_arg(fill: Union[_FillType, Dict[Union[Type, str], _FillType]]) -> None:
     if isinstance(fill, dict):
-        for key, value in fill.items():
-            # Check key for type
+        for value in fill.values():
             _check_fill_arg(value)
-        if isinstance(fill, defaultdict) and callable(fill.default_factory):
-            default_value = fill.default_factory()
-            _check_fill_arg(default_value)
     else:
         if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)):
             raise TypeError("Got inappropriate fill arg, only Numbers, tuples, lists and dicts are allowed.")
 
 
-T = TypeVar("T")
-
-
-def _default_arg(value: T) -> T:
-    return value
-
-
-def _get_defaultdict(default: T) -> Dict[Any, T]:
-    # This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle.
-    # If it were possible, we could replace this with `defaultdict(lambda: default)`
-    return defaultdict(functools.partial(_default_arg, default))
-
-
 def _convert_fill_arg(fill: datapoints._FillType) -> datapoints._FillTypeJIT:
     # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
     # So, we can't reassign fill to 0
@@ -65,19 +49,24 @@ def _convert_fill_arg(fill: datapoints._FillType) -> datapoints._FillTypeJIT:
     return fill  # type: ignore[return-value]
 
 
-def _setup_fill_arg(fill: Union[_FillType, Dict[Type, _FillType]]) -> Dict[Type, _FillTypeJIT]:
+def _setup_fill_arg(fill: Union[_FillType, Dict[Union[Type, str], _FillType]]) -> Dict[Union[Type, str], _FillTypeJIT]:
     _check_fill_arg(fill)
 
     if isinstance(fill, dict):
         for k, v in fill.items():
             fill[k] = _convert_fill_arg(v)
-        if isinstance(fill, defaultdict) and callable(fill.default_factory):
-            default_value = fill.default_factory()
-            sanitized_default = _convert_fill_arg(default_value)
-            fill.default_factory = functools.partial(_default_arg, sanitized_default)
         return fill  # type: ignore[return-value]
+    else:
+        return {"others": _convert_fill_arg(fill)}
 
-    return _get_defaultdict(_convert_fill_arg(fill))
+
+def _get_fill(fill_dict, inpt_type):
+    if inpt_type in fill_dict:
+        return fill_dict[inpt_type]
+    elif "others" in fill_dict:
+        return fill_dict["others"]
+    else:
+        raise RuntimeError("This should never happen, please open an issue on the torchvision repo if you hit this.")
 
 
 def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
@@ -93,3 +82,60 @@ def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None:
     if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
         raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
+
+
+def _find_labels_default_heuristic(inputs: Any) -> torch.Tensor:
+    """
+    This heuristic covers three cases:
+
+    1. The input is a tuple or list whose second item is a labels tensor. This happens for already batched
+       classification inputs for MixUp and CutMix (typically after the Dataloader).
+    2. The input is a tuple or list whose second item is a dictionary that contains the labels tensor
+       under a label-like (see below) key. This happens for the inputs of detection models.
+    3. The input is a dictionary that is structured as the one from 2.
+
+    What is a "label-like" key? We first search for a case-insensitive match of 'labels' inside the keys of the
+    dictionary. This is the name our detection models expect. If we can't find that, we look for a case-insensitive
+    match of the term 'label' anywhere inside the key, i.e. 'FooLaBeLBar'. If we can't find that either, the dictionary
+    contains no "label-like" key.
+    """
+
+    if isinstance(inputs, (tuple, list)):
+        inputs = inputs[1]
+
+    # MixUp, CutMix
+    if isinstance(inputs, torch.Tensor):
+        return inputs
+
+    if not isinstance(inputs, collections.abc.Mapping):
+        raise ValueError(
+            f"When using the default labels_getter, the input passed to forward must be a dictionary or a two-tuple "
+            f"whose second item is a dictionary or a tensor, but got {inputs} instead."
+ ) + + candidate_key = None + with suppress(StopIteration): + candidate_key = next(key for key in inputs.keys() if key.lower() == "labels") + if candidate_key is None: + with suppress(StopIteration): + candidate_key = next(key for key in inputs.keys() if "label" in key.lower()) + if candidate_key is None: + raise ValueError( + "Could not infer where the labels are in the sample. Try passing a callable as the labels_getter parameter?" + "If there are no labels in the sample by design, pass labels_getter=None." + ) + + return inputs[candidate_key] + + +def _parse_labels_getter( + labels_getter: Union[str, Callable[[Any], Optional[torch.Tensor]], None] +) -> Callable[[Any], Optional[torch.Tensor]]: + if labels_getter == "default": + return _find_labels_default_heuristic + elif callable(labels_getter): + return labels_getter + elif labels_getter is None: + return lambda _: None + else: + raise ValueError(f"labels_getter should either be 'default', a callable, or None, but got {labels_getter}.") diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index b4803f4f1b9..24b4b4218e0 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -3,12 +3,8 @@ from ._utils import is_simple_tensor # usort: skip from ._meta import ( - clamp_bounding_box, - convert_format_bounding_box, - convert_dtype_image_tensor, - convert_dtype, - convert_dtype_video, - convert_image_dtype, + clamp_bounding_boxes, + convert_format_bounding_boxes, get_dimensions_image_tensor, get_dimensions_image_pil, get_dimensions, @@ -19,12 +15,12 @@ get_num_channels_image_pil, get_num_channels_video, get_num_channels, - get_spatial_size_bounding_box, - get_spatial_size_image_tensor, - get_spatial_size_image_pil, - get_spatial_size_mask, - get_spatial_size_video, - get_spatial_size, + get_size_bounding_boxes, + get_size_image_tensor, + get_size_image_pil, + get_size_mask, + get_size_video, + get_size, ) # usort: skip from ._augment import erase, erase_image_pil, erase_image_tensor, erase_video @@ -80,25 +76,25 @@ ) from ._geometry import ( affine, - affine_bounding_box, + affine_bounding_boxes, affine_image_pil, affine_image_tensor, affine_mask, affine_video, center_crop, - center_crop_bounding_box, + center_crop_bounding_boxes, center_crop_image_pil, center_crop_image_tensor, center_crop_mask, center_crop_video, crop, - crop_bounding_box, + crop_bounding_boxes, crop_image_pil, crop_image_tensor, crop_mask, crop_video, elastic, - elastic_bounding_box, + elastic_bounding_boxes, elastic_image_pil, elastic_image_tensor, elastic_mask, @@ -110,37 +106,37 @@ five_crop_video, hflip, # TODO: Consider moving all pure alias definitions at the bottom of the file horizontal_flip, - horizontal_flip_bounding_box, + horizontal_flip_bounding_boxes, horizontal_flip_image_pil, horizontal_flip_image_tensor, horizontal_flip_mask, horizontal_flip_video, pad, - pad_bounding_box, + pad_bounding_boxes, pad_image_pil, pad_image_tensor, pad_mask, pad_video, perspective, - perspective_bounding_box, + perspective_bounding_boxes, perspective_image_pil, perspective_image_tensor, perspective_mask, perspective_video, resize, - resize_bounding_box, + resize_bounding_boxes, resize_image_pil, resize_image_tensor, resize_mask, resize_video, resized_crop, - resized_crop_bounding_box, + resized_crop_bounding_boxes, resized_crop_image_pil, resized_crop_image_tensor, resized_crop_mask, resized_crop_video, rotate, - rotate_bounding_box, + 
rotate_bounding_boxes, rotate_image_pil, rotate_image_tensor, rotate_mask, @@ -150,7 +146,7 @@ ten_crop_image_tensor, ten_crop_video, vertical_flip, - vertical_flip_bounding_box, + vertical_flip_bounding_boxes, vertical_flip_image_pil, vertical_flip_image_tensor, vertical_flip_mask, @@ -158,6 +154,7 @@ vflip, ) from ._misc import ( + convert_image_dtype, gaussian_blur, gaussian_blur_image_pil, gaussian_blur_image_tensor, @@ -165,6 +162,9 @@ normalize, normalize_image_tensor, normalize_video, + to_dtype, + to_dtype_image_tensor, + to_dtype_video, ) from ._temporal import uniform_temporal_subsample, uniform_temporal_subsample_video from ._type_conversion import pil_to_tensor, to_image_pil, to_image_tensor, to_pil_image diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index 13417e4a990..32568f728cf 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -9,7 +9,7 @@ from torchvision.utils import _log_api_usage_once -from ._meta import _num_value_bits, convert_dtype_image_tensor +from ._misc import _num_value_bits, to_dtype_image_tensor from ._utils import is_simple_tensor @@ -351,7 +351,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten return image orig_dtype = image.dtype - image = convert_dtype_image_tensor(image, torch.float32) + image = to_dtype_image_tensor(image, torch.float32, scale=True) image = _rgb_to_hsv(image) h, s, v = image.unbind(dim=-3) @@ -359,7 +359,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten image = torch.stack((h, s, v), dim=-3) image_hue_adj = _hsv_to_rgb(image) - return convert_dtype_image_tensor(image_hue_adj, orig_dtype) + return to_dtype_image_tensor(image_hue_adj, orig_dtype, scale=True) adjust_hue_image_pil = _FP.adjust_hue @@ -393,7 +393,7 @@ def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1 # The input image is either assumed to be at [0, 1] scale (if float) or is converted to that scale (if integer). # Since the gamma is non-negative, the output remains at [0, 1] scale. if not torch.is_floating_point(image): - output = convert_dtype_image_tensor(image, torch.float32).pow_(gamma) + output = to_dtype_image_tensor(image, torch.float32, scale=True).pow_(gamma) else: output = image.pow(gamma) @@ -402,7 +402,7 @@ def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1 # of the output can go beyond [0, 1]. output = output.mul_(gain).clamp_(0.0, 1.0) - return convert_dtype_image_tensor(output, image.dtype) + return to_dtype_image_tensor(output, image.dtype, scale=True) adjust_gamma_image_pil = _FP.adjust_gamma @@ -565,7 +565,7 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor: # Since we need to convert in most cases anyway and out of the acceptable dtypes mentioned in 1. `torch.uint8` is # by far the most common, we choose it as base. output_dtype = image.dtype - image = convert_dtype_image_tensor(image, torch.uint8) + image = to_dtype_image_tensor(image, torch.uint8, scale=True) # The histogram is computed by using the flattened image as index. For example, a pixel value of 127 in the image # corresponds to adding 1 to index 127 in the histogram. 
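# --- Illustrative sketch (editor's addition, not part of the patch): the scale=True
# round trip used above. adjust_hue_image_tensor() and adjust_gamma_image_tensor() now
# call to_dtype_image_tensor(..., scale=True) to move an integer image into [0, 1]
# float32, operate there, and scale back to the original dtype. The constants below
# spell out the uint8 case by hand, only to illustrate the value mapping.
import torch

image_u8 = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8)

# int -> float with scaling: divide by the dtype's maximum value (255 for uint8)
image_f32 = image_u8.to(torch.float32).mul_(1.0 / 255.0)

# ... float-only processing would happen here (e.g. the RGB <-> HSV math in adjust_hue) ...

# float -> int with scaling: multiply by (max + 1 - eps) so that 1.0 lands on 255
image_back = image_f32.mul(256.0 - 1e-3).to(torch.uint8)
assert torch.equal(image_back, image_u8)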
@@ -616,7 +616,7 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor: equalized_image = lut.gather(dim=-1, index=flat_image).view_as(image) output = torch.where(valid_equalization, equalized_image, image) - return convert_dtype_image_tensor(output, output_dtype) + return to_dtype_image_tensor(output, output_dtype, scale=True) equalize_image_pil = _FP.equalize diff --git a/torchvision/transforms/v2/functional/_deprecated.py b/torchvision/transforms/v2/functional/_deprecated.py index c9a0f647e60..f27d0b29deb 100644 --- a/torchvision/transforms/v2/functional/_deprecated.py +++ b/torchvision/transforms/v2/functional/_deprecated.py @@ -11,7 +11,7 @@ def to_tensor(inpt: Any) -> torch.Tensor: warnings.warn( "The function `to_tensor(...)` is deprecated and will be removed in a future release. " - "Instead, please use `to_image_tensor(...)` followed by `convert_image_dtype(...)`." + "Instead, please use `to_image_tensor(...)` followed by `to_dtype(..., dtype=torch.float32, scale=True)`." ) return _F.to_tensor(inpt) @@ -19,6 +19,6 @@ def to_tensor(inpt: Any) -> torch.Tensor: def get_image_size(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> List[int]: warnings.warn( "The function `get_image_size(...)` is deprecated and will be removed in a future release. " - "Instead, please use `get_spatial_size(...)` which returns `[h, w]` instead of `[w, h]`." + "Instead, please use `get_size(...)` which returns `[h, w]` instead of `[w, h]`." ) return _F.get_image_size(inpt) diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index e1dd2866bc5..a24507256be 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -23,7 +23,7 @@ from torchvision.utils import _log_api_usage_once -from ._meta import clamp_bounding_box, convert_format_bounding_box, get_spatial_size_image_pil +from ._meta import clamp_bounding_boxes, convert_format_bounding_boxes, get_size_image_pil from ._utils import is_simple_tensor @@ -51,21 +51,21 @@ def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor: return horizontal_flip_image_tensor(mask) -def horizontal_flip_bounding_box( - bounding_box: torch.Tensor, format: datapoints.BoundingBoxFormat, spatial_size: Tuple[int, int] +def horizontal_flip_bounding_boxes( + bounding_boxes: torch.Tensor, format: datapoints.BoundingBoxFormat, canvas_size: Tuple[int, int] ) -> torch.Tensor: - shape = bounding_box.shape + shape = bounding_boxes.shape - bounding_box = bounding_box.clone().reshape(-1, 4) + bounding_boxes = bounding_boxes.clone().reshape(-1, 4) if format == datapoints.BoundingBoxFormat.XYXY: - bounding_box[:, [2, 0]] = bounding_box[:, [0, 2]].sub_(spatial_size[1]).neg_() + bounding_boxes[:, [2, 0]] = bounding_boxes[:, [0, 2]].sub_(canvas_size[1]).neg_() elif format == datapoints.BoundingBoxFormat.XYWH: - bounding_box[:, 0].add_(bounding_box[:, 2]).sub_(spatial_size[1]).neg_() + bounding_boxes[:, 0].add_(bounding_boxes[:, 2]).sub_(canvas_size[1]).neg_() else: # format == datapoints.BoundingBoxFormat.CXCYWH: - bounding_box[:, 0].sub_(spatial_size[1]).neg_() + bounding_boxes[:, 0].sub_(canvas_size[1]).neg_() - return bounding_box.reshape(shape) + return bounding_boxes.reshape(shape) def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor: @@ -101,21 +101,21 @@ def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor: return vertical_flip_image_tensor(mask) -def vertical_flip_bounding_box( - bounding_box: torch.Tensor, 
format: datapoints.BoundingBoxFormat, spatial_size: Tuple[int, int] +def vertical_flip_bounding_boxes( + bounding_boxes: torch.Tensor, format: datapoints.BoundingBoxFormat, canvas_size: Tuple[int, int] ) -> torch.Tensor: - shape = bounding_box.shape + shape = bounding_boxes.shape - bounding_box = bounding_box.clone().reshape(-1, 4) + bounding_boxes = bounding_boxes.clone().reshape(-1, 4) if format == datapoints.BoundingBoxFormat.XYXY: - bounding_box[:, [1, 3]] = bounding_box[:, [3, 1]].sub_(spatial_size[0]).neg_() + bounding_boxes[:, [1, 3]] = bounding_boxes[:, [3, 1]].sub_(canvas_size[0]).neg_() elif format == datapoints.BoundingBoxFormat.XYWH: - bounding_box[:, 1].add_(bounding_box[:, 3]).sub_(spatial_size[0]).neg_() + bounding_boxes[:, 1].add_(bounding_boxes[:, 3]).sub_(canvas_size[0]).neg_() else: # format == datapoints.BoundingBoxFormat.CXCYWH: - bounding_box[:, 1].sub_(spatial_size[0]).neg_() + bounding_boxes[:, 1].sub_(canvas_size[0]).neg_() - return bounding_box.reshape(shape) + return bounding_boxes.reshape(shape) def vertical_flip_video(video: torch.Tensor) -> torch.Tensor: @@ -146,7 +146,7 @@ def vertical_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT: def _compute_resized_output_size( - spatial_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None + canvas_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None ) -> List[int]: if isinstance(size, int): size = [size] @@ -155,7 +155,7 @@ def _compute_resized_output_size( "max_size should only be passed if size specifies the length of the smaller edge, " "i.e. size should be an int or a sequence of length 1 in torchscript mode." ) - return __compute_resized_output_size(spatial_size, size=size, max_size=max_size) + return __compute_resized_output_size(canvas_size, size=size, max_size=max_size) def resize_image_tensor( @@ -274,20 +274,20 @@ def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = N return output -def resize_bounding_box( - bounding_box: torch.Tensor, spatial_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None +def resize_bounding_boxes( + bounding_boxes: torch.Tensor, canvas_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None ) -> Tuple[torch.Tensor, Tuple[int, int]]: - old_height, old_width = spatial_size - new_height, new_width = _compute_resized_output_size(spatial_size, size=size, max_size=max_size) + old_height, old_width = canvas_size + new_height, new_width = _compute_resized_output_size(canvas_size, size=size, max_size=max_size) if (new_height, new_width) == (old_height, old_width): - return bounding_box, spatial_size + return bounding_boxes, canvas_size w_ratio = new_width / old_width h_ratio = new_height / old_height - ratios = torch.tensor([w_ratio, h_ratio, w_ratio, h_ratio], device=bounding_box.device) + ratios = torch.tensor([w_ratio, h_ratio, w_ratio, h_ratio], device=bounding_boxes.device) return ( - bounding_box.mul(ratios).to(bounding_box.dtype), + bounding_boxes.mul(ratios).to(bounding_boxes.dtype), (new_height, new_width), ) @@ -643,17 +643,17 @@ def affine_image_pil( # it is visually better to estimate the center without 0.5 offset # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine if center is None: - height, width = get_spatial_size_image_pil(image) + height, width = get_size_image_pil(image) center = [width * 0.5, height * 0.5] matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear) return _FP.affine(image, matrix, 
interpolation=pil_modes_mapping[interpolation], fill=fill) -def _affine_bounding_box_with_expand( - bounding_box: torch.Tensor, +def _affine_bounding_boxes_with_expand( + bounding_boxes: torch.Tensor, format: datapoints.BoundingBoxFormat, - spatial_size: Tuple[int, int], + canvas_size: Tuple[int, int], angle: Union[int, float], translate: List[float], scale: float, @@ -661,17 +661,17 @@ def _affine_bounding_box_with_expand( center: Optional[List[float]] = None, expand: bool = False, ) -> Tuple[torch.Tensor, Tuple[int, int]]: - if bounding_box.numel() == 0: - return bounding_box, spatial_size - - original_shape = bounding_box.shape - original_dtype = bounding_box.dtype - bounding_box = bounding_box.clone() if bounding_box.is_floating_point() else bounding_box.float() - dtype = bounding_box.dtype - device = bounding_box.device - bounding_box = ( - convert_format_bounding_box( - bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True + if bounding_boxes.numel() == 0: + return bounding_boxes, canvas_size + + original_shape = bounding_boxes.shape + original_dtype = bounding_boxes.dtype + bounding_boxes = bounding_boxes.clone() if bounding_boxes.is_floating_point() else bounding_boxes.float() + dtype = bounding_boxes.dtype + device = bounding_boxes.device + bounding_boxes = ( + convert_format_bounding_boxes( + bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True ) ).reshape(-1, 4) @@ -680,7 +680,7 @@ def _affine_bounding_box_with_expand( ) if center is None: - height, width = spatial_size + height, width = canvas_size center = [width * 0.5, height * 0.5] affine_vector = _get_inverse_affine_matrix(center, angle, translate, scale, shear, inverted=False) @@ -697,7 +697,7 @@ def _affine_bounding_box_with_expand( # Tensor of points has shape (N * 4, 3), where N is the number of bboxes # Single point structure is similar to # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)] - points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2) + points = bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2) points = torch.cat([points, torch.ones(points.shape[0], 1, device=device, dtype=dtype)], dim=-1) # 2) Now let's transform the points using affine matrix transformed_points = torch.matmul(points, transposed_affine_matrix) @@ -710,7 +710,7 @@ def _affine_bounding_box_with_expand( if expand: # Compute minimum point for transformed image frame: # Points are Top-Left, Top-Right, Bottom-Left, Bottom-Right points. 
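# --- Illustrative sketch (editor's addition, not part of the patch): the corner-point
# trick used above. Each XYXY box is expanded into its four corners in homogeneous
# coordinates, the corners are pushed through the affine matrix, and the output box is
# the axis-aligned min/max of the transformed corners. Standalone example with a 90
# degree rotation about the origin (the kernel builds its matrix from angle/translate/
# scale/shear instead).
import math
import torch

boxes = torch.tensor([[10.0, 10.0, 30.0, 20.0]])  # N x 4, XYXY
theta = math.radians(90)
affine = torch.tensor(
    [[math.cos(theta), -math.sin(theta), 0.0], [math.sin(theta), math.cos(theta), 0.0]]
)

# (N * 4, 3): [(x1, y1, 1), (x2, y1, 1), (x2, y2, 1), (x1, y2, 1)] per box
points = boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
points = torch.cat([points, torch.ones(points.shape[0], 1)], dim=-1)

transformed = torch.matmul(points, affine.T).reshape(-1, 4, 2)
out_mins, out_maxs = torch.aminmax(transformed, dim=1)
print(torch.cat([out_mins, out_maxs], dim=1))  # enclosing boxes, still XYXY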
- height, width = spatial_size + height, width = canvas_size points = torch.tensor( [ [0.0, 0.0, 1.0], @@ -728,31 +728,31 @@ def _affine_bounding_box_with_expand( # Estimate meta-data for image with inverted=True and with center=[0,0] affine_vector = _get_inverse_affine_matrix([0.0, 0.0], angle, translate, scale, shear) new_width, new_height = _compute_affine_output_size(affine_vector, width, height) - spatial_size = (new_height, new_width) + canvas_size = (new_height, new_width) - out_bboxes = clamp_bounding_box(out_bboxes, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size) - out_bboxes = convert_format_bounding_box( + out_bboxes = clamp_bounding_boxes(out_bboxes, format=datapoints.BoundingBoxFormat.XYXY, canvas_size=canvas_size) + out_bboxes = convert_format_bounding_boxes( out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True ).reshape(original_shape) out_bboxes = out_bboxes.to(original_dtype) - return out_bboxes, spatial_size + return out_bboxes, canvas_size -def affine_bounding_box( - bounding_box: torch.Tensor, +def affine_bounding_boxes( + bounding_boxes: torch.Tensor, format: datapoints.BoundingBoxFormat, - spatial_size: Tuple[int, int], + canvas_size: Tuple[int, int], angle: Union[int, float], translate: List[float], scale: float, shear: List[float], center: Optional[List[float]] = None, ) -> torch.Tensor: - out_box, _ = _affine_bounding_box_with_expand( - bounding_box, + out_box, _ = _affine_bounding_boxes_with_expand( + bounding_boxes, format=format, - spatial_size=spatial_size, + canvas_size=canvas_size, angle=angle, translate=translate, scale=scale, @@ -927,10 +927,10 @@ def rotate_image_pil( ) -def rotate_bounding_box( - bounding_box: torch.Tensor, +def rotate_bounding_boxes( + bounding_boxes: torch.Tensor, format: datapoints.BoundingBoxFormat, - spatial_size: Tuple[int, int], + canvas_size: Tuple[int, int], angle: float, expand: bool = False, center: Optional[List[float]] = None, @@ -938,10 +938,10 @@ def rotate_bounding_box( if center is not None and expand: warnings.warn("The provided center argument has no effect on the result if expand is True") - return _affine_bounding_box_with_expand( - bounding_box, + return _affine_bounding_boxes_with_expand( + bounding_boxes, format=format, - spatial_size=spatial_size, + canvas_size=canvas_size, angle=-angle, translate=[0.0, 0.0], scale=1.0, @@ -1165,10 +1165,10 @@ def pad_mask( return output -def pad_bounding_box( - bounding_box: torch.Tensor, +def pad_bounding_boxes( + bounding_boxes: torch.Tensor, format: datapoints.BoundingBoxFormat, - spatial_size: Tuple[int, int], + canvas_size: Tuple[int, int], padding: List[int], padding_mode: str = "constant", ) -> Tuple[torch.Tensor, Tuple[int, int]]: @@ -1182,14 +1182,14 @@ def pad_bounding_box( pad = [left, top, left, top] else: pad = [left, top, 0, 0] - bounding_box = bounding_box + torch.tensor(pad, dtype=bounding_box.dtype, device=bounding_box.device) + bounding_boxes = bounding_boxes + torch.tensor(pad, dtype=bounding_boxes.dtype, device=bounding_boxes.device) - height, width = spatial_size + height, width = canvas_size height += top + bottom width += left + right - spatial_size = (height, width) + canvas_size = (height, width) - return clamp_bounding_box(bounding_box, format=format, spatial_size=spatial_size), spatial_size + return clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size), canvas_size def pad_video( @@ -1245,8 +1245,8 @@ def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: 
int, wid crop_image_pil = _FP.crop -def crop_bounding_box( - bounding_box: torch.Tensor, +def crop_bounding_boxes( + bounding_boxes: torch.Tensor, format: datapoints.BoundingBoxFormat, top: int, left: int, @@ -1260,10 +1260,10 @@ def crop_bounding_box( else: sub = [left, top, 0, 0] - bounding_box = bounding_box - torch.tensor(sub, dtype=bounding_box.dtype, device=bounding_box.device) - spatial_size = (height, width) + bounding_boxes = bounding_boxes - torch.tensor(sub, dtype=bounding_boxes.dtype, device=bounding_boxes.device) + canvas_size = (height, width) - return clamp_bounding_box(bounding_box, format=format, spatial_size=spatial_size), spatial_size + return clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size), canvas_size def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor: @@ -1409,27 +1409,27 @@ def perspective_image_pil( return _FP.perspective(image, perspective_coeffs, interpolation=pil_modes_mapping[interpolation], fill=fill) -def perspective_bounding_box( - bounding_box: torch.Tensor, +def perspective_bounding_boxes( + bounding_boxes: torch.Tensor, format: datapoints.BoundingBoxFormat, - spatial_size: Tuple[int, int], + canvas_size: Tuple[int, int], startpoints: Optional[List[List[int]]], endpoints: Optional[List[List[int]]], coefficients: Optional[List[float]] = None, ) -> torch.Tensor: - if bounding_box.numel() == 0: - return bounding_box + if bounding_boxes.numel() == 0: + return bounding_boxes perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients) - original_shape = bounding_box.shape - # TODO: first cast to float if bbox is int64 before convert_format_bounding_box - bounding_box = ( - convert_format_bounding_box(bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY) + original_shape = bounding_boxes.shape + # TODO: first cast to float if bbox is int64 before convert_format_bounding_boxes + bounding_boxes = ( + convert_format_bounding_boxes(bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY) ).reshape(-1, 4) - dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32 - device = bounding_box.device + dtype = bounding_boxes.dtype if torch.is_floating_point(bounding_boxes) else torch.float32 + device = bounding_boxes.device # perspective_coeffs are computed as endpoint -> start point # We have to invert perspective_coeffs for bboxes: @@ -1475,7 +1475,7 @@ def perspective_bounding_box( # Tensor of points has shape (N * 4, 3), where N is the number of bboxes # Single point structure is similar to # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)] - points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2) + points = bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2) points = torch.cat([points, torch.ones(points.shape[0], 1, device=points.device)], dim=-1) # 2) Now let's transform the points using perspective matrices # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1) @@ -1490,15 +1490,15 @@ def perspective_bounding_box( transformed_points = transformed_points.reshape(-1, 4, 2) out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1) - out_bboxes = clamp_bounding_box( - torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype), + out_bboxes = clamp_bounding_boxes( + torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype), format=datapoints.BoundingBoxFormat.XYXY, - 
spatial_size=spatial_size, + canvas_size=canvas_size, ) # out_bboxes should be of shape [N boxes, 4] - return convert_format_bounding_box( + return convert_format_bounding_boxes( out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True ).reshape(original_shape) @@ -1648,53 +1648,53 @@ def _create_identity_grid(size: Tuple[int, int], device: torch.device, dtype: to return base_grid -def elastic_bounding_box( - bounding_box: torch.Tensor, +def elastic_bounding_boxes( + bounding_boxes: torch.Tensor, format: datapoints.BoundingBoxFormat, - spatial_size: Tuple[int, int], + canvas_size: Tuple[int, int], displacement: torch.Tensor, ) -> torch.Tensor: - if bounding_box.numel() == 0: - return bounding_box + if bounding_boxes.numel() == 0: + return bounding_boxes # TODO: add in docstring about approximation we are doing for grid inversion - device = bounding_box.device - dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32 + device = bounding_boxes.device + dtype = bounding_boxes.dtype if torch.is_floating_point(bounding_boxes) else torch.float32 if displacement.dtype != dtype or displacement.device != device: displacement = displacement.to(dtype=dtype, device=device) - original_shape = bounding_box.shape - # TODO: first cast to float if bbox is int64 before convert_format_bounding_box - bounding_box = ( - convert_format_bounding_box(bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY) + original_shape = bounding_boxes.shape + # TODO: first cast to float if bbox is int64 before convert_format_bounding_boxes + bounding_boxes = ( + convert_format_bounding_boxes(bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY) ).reshape(-1, 4) - id_grid = _create_identity_grid(spatial_size, device=device, dtype=dtype) + id_grid = _create_identity_grid(canvas_size, device=device, dtype=dtype) # We construct an approximation of inverse grid as inv_grid = id_grid - displacement # This is not an exact inverse of the grid inv_grid = id_grid.sub_(displacement) # Get points from bboxes - points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2) + points = bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2) if points.is_floating_point(): points = points.ceil_() index_xy = points.to(dtype=torch.long) index_x, index_y = index_xy[:, 0], index_xy[:, 1] # Transform points: - t_size = torch.tensor(spatial_size[::-1], device=displacement.device, dtype=displacement.dtype) + t_size = torch.tensor(canvas_size[::-1], device=displacement.device, dtype=displacement.dtype) transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5) transformed_points = transformed_points.reshape(-1, 4, 2) out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1) - out_bboxes = clamp_bounding_box( - torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype), + out_bboxes = clamp_bounding_boxes( + torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype), format=datapoints.BoundingBoxFormat.XYXY, - spatial_size=spatial_size, + canvas_size=canvas_size, ) - return convert_format_bounding_box( + return convert_format_bounding_boxes( out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True ).reshape(original_shape) @@ -1804,13 +1804,13 @@ def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> tor @torch.jit.unused def center_crop_image_pil(image: PIL.Image.Image, output_size: 
List[int]) -> PIL.Image.Image: crop_height, crop_width = _center_crop_parse_output_size(output_size) - image_height, image_width = get_spatial_size_image_pil(image) + image_height, image_width = get_size_image_pil(image) if crop_height > image_height or crop_width > image_width: padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width) image = pad_image_pil(image, padding_ltrb, fill=0) - image_height, image_width = get_spatial_size_image_pil(image) + image_height, image_width = get_size_image_pil(image) if crop_width == image_width and crop_height == image_height: return image @@ -1818,15 +1818,17 @@ def center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL return crop_image_pil(image, crop_top, crop_left, crop_height, crop_width) -def center_crop_bounding_box( - bounding_box: torch.Tensor, +def center_crop_bounding_boxes( + bounding_boxes: torch.Tensor, format: datapoints.BoundingBoxFormat, - spatial_size: Tuple[int, int], + canvas_size: Tuple[int, int], output_size: List[int], ) -> Tuple[torch.Tensor, Tuple[int, int]]: crop_height, crop_width = _center_crop_parse_output_size(output_size) - crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *spatial_size) - return crop_bounding_box(bounding_box, format, top=crop_top, left=crop_left, height=crop_height, width=crop_width) + crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *canvas_size) + return crop_bounding_boxes( + bounding_boxes, format, top=crop_top, left=crop_left, height=crop_height, width=crop_width + ) def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor: @@ -1893,8 +1895,8 @@ def resized_crop_image_pil( return resize_image_pil(image, size, interpolation=interpolation) -def resized_crop_bounding_box( - bounding_box: torch.Tensor, +def resized_crop_bounding_boxes( + bounding_boxes: torch.Tensor, format: datapoints.BoundingBoxFormat, top: int, left: int, @@ -1902,8 +1904,8 @@ def resized_crop_bounding_box( width: int, size: List[int], ) -> Tuple[torch.Tensor, Tuple[int, int]]: - bounding_box, _ = crop_bounding_box(bounding_box, format, top, left, height, width) - return resize_bounding_box(bounding_box, spatial_size=(height, width), size=size) + bounding_boxes, _ = crop_bounding_boxes(bounding_boxes, format, top, left, height, width) + return resize_bounding_boxes(bounding_boxes, canvas_size=(height, width), size=size) def resized_crop_mask( @@ -1998,7 +2000,7 @@ def five_crop_image_pil( image: PIL.Image.Image, size: List[int] ) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]: crop_height, crop_width = _parse_five_crop_size(size) - image_height, image_width = get_spatial_size_image_pil(image) + image_height, image_width = get_size_image_pil(image) if crop_width > image_width or crop_height > image_height: raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}") diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index 8ffa3966195..91b370675b9 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -5,7 +5,6 @@ from torchvision import datapoints from torchvision.datapoints import BoundingBoxFormat from torchvision.transforms import _functional_pil as _FP -from torchvision.transforms._functional_tensor import _max_value from torchvision.utils import _log_api_usage_once @@ -27,23 +26,29 @@ 
def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]: get_dimensions_image_pil = _FP.get_dimensions +def get_dimensions_video(video: torch.Tensor) -> List[int]: + return get_dimensions_image_tensor(video) + + def get_dimensions(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> List[int]: if not torch.jit.is_scripting(): _log_api_usage_once(get_dimensions) if torch.jit.is_scripting() or is_simple_tensor(inpt): return get_dimensions_image_tensor(inpt) - elif isinstance(inpt, (datapoints.Image, datapoints.Video)): - channels = inpt.num_channels - height, width = inpt.spatial_size - return [channels, height, width] - elif isinstance(inpt, PIL.Image.Image): - return get_dimensions_image_pil(inpt) - else: - raise TypeError( - f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) + + for typ, get_size_fn in { + datapoints.Image: get_dimensions_image_tensor, + datapoints.Video: get_dimensions_video, + PIL.Image.Image: get_dimensions_image_pil, + }.items(): + if isinstance(inpt, typ): + return get_size_fn(inpt) + + raise TypeError( + f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, " + f"but got {type(inpt)} instead." + ) def get_num_channels_image_tensor(image: torch.Tensor) -> int: @@ -70,15 +75,19 @@ def get_num_channels(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoType if torch.jit.is_scripting() or is_simple_tensor(inpt): return get_num_channels_image_tensor(inpt) - elif isinstance(inpt, (datapoints.Image, datapoints.Video)): - return inpt.num_channels - elif isinstance(inpt, PIL.Image.Image): - return get_num_channels_image_pil(inpt) - else: - raise TypeError( - f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) + + for typ, get_size_fn in { + datapoints.Image: get_num_channels_image_tensor, + datapoints.Video: get_num_channels_video, + PIL.Image.Image: get_num_channels_image_pil, + }.items(): + if isinstance(inpt, typ): + return get_size_fn(inpt) + + raise TypeError( + f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, " + f"but got {type(inpt)} instead." + ) # We changed the names to ensure it can be used not only for images but also videos. 
Thus, we just alias it without @@ -86,7 +95,7 @@ def get_num_channels(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoType get_image_num_channels = get_num_channels -def get_spatial_size_image_tensor(image: torch.Tensor) -> List[int]: +def get_size_image_tensor(image: torch.Tensor) -> List[int]: hw = list(image.shape[-2:]) ndims = len(hw) if ndims == 2: @@ -96,39 +105,48 @@ def get_spatial_size_image_tensor(image: torch.Tensor) -> List[int]: @torch.jit.unused -def get_spatial_size_image_pil(image: PIL.Image.Image) -> List[int]: +def get_size_image_pil(image: PIL.Image.Image) -> List[int]: width, height = _FP.get_image_size(image) return [height, width] -def get_spatial_size_video(video: torch.Tensor) -> List[int]: - return get_spatial_size_image_tensor(video) +def get_size_video(video: torch.Tensor) -> List[int]: + return get_size_image_tensor(video) -def get_spatial_size_mask(mask: torch.Tensor) -> List[int]: - return get_spatial_size_image_tensor(mask) +def get_size_mask(mask: torch.Tensor) -> List[int]: + return get_size_image_tensor(mask) @torch.jit.unused -def get_spatial_size_bounding_box(bounding_box: datapoints.BoundingBox) -> List[int]: - return list(bounding_box.spatial_size) +def get_size_bounding_boxes(bounding_box: datapoints.BoundingBoxes) -> List[int]: + return list(bounding_box.canvas_size) -def get_spatial_size(inpt: datapoints._InputTypeJIT) -> List[int]: +def get_size(inpt: datapoints._InputTypeJIT) -> List[int]: if not torch.jit.is_scripting(): - _log_api_usage_once(get_spatial_size) + _log_api_usage_once(get_size) if torch.jit.is_scripting() or is_simple_tensor(inpt): - return get_spatial_size_image_tensor(inpt) - elif isinstance(inpt, (datapoints.Image, datapoints.Video, datapoints.BoundingBox, datapoints.Mask)): - return list(inpt.spatial_size) - elif isinstance(inpt, PIL.Image.Image): - return get_spatial_size_image_pil(inpt) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) + return get_size_image_tensor(inpt) + + # TODO: This is just the poor mans version of a dispatcher. This will be properly addressed with + # https://github.com/pytorch/vision/pull/7747 when we can register the kernels above without the need to have + # a method on the datapoint class + for typ, get_size_fn in { + datapoints.Image: get_size_image_tensor, + datapoints.BoundingBoxes: get_size_bounding_boxes, + datapoints.Mask: get_size_mask, + datapoints.Video: get_size_video, + PIL.Image.Image: get_size_image_pil, + }.items(): + if isinstance(inpt, typ): + return get_size_fn(inpt) + + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." 
+ ) def get_num_frames_video(video: torch.Tensor) -> int: @@ -142,7 +160,7 @@ def get_num_frames(inpt: datapoints._VideoTypeJIT) -> int: if torch.jit.is_scripting() or is_simple_tensor(inpt): return get_num_frames_video(inpt) elif isinstance(inpt, datapoints.Video): - return inpt.num_frames + return get_num_frames_video(inpt) else: raise TypeError(f"Input can either be a plain tensor or a `Video` datapoint, but got {type(inpt)} instead.") @@ -186,189 +204,97 @@ def _xyxy_to_cxcywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor: return xyxy -def _convert_format_bounding_box( - bounding_box: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat, inplace: bool = False +def _convert_format_bounding_boxes( + bounding_boxes: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat, inplace: bool = False ) -> torch.Tensor: if new_format == old_format: - return bounding_box + return bounding_boxes # TODO: Add _xywh_to_cxcywh and _cxcywh_to_xywh to improve performance if old_format == BoundingBoxFormat.XYWH: - bounding_box = _xywh_to_xyxy(bounding_box, inplace) + bounding_boxes = _xywh_to_xyxy(bounding_boxes, inplace) elif old_format == BoundingBoxFormat.CXCYWH: - bounding_box = _cxcywh_to_xyxy(bounding_box, inplace) + bounding_boxes = _cxcywh_to_xyxy(bounding_boxes, inplace) if new_format == BoundingBoxFormat.XYWH: - bounding_box = _xyxy_to_xywh(bounding_box, inplace) + bounding_boxes = _xyxy_to_xywh(bounding_boxes, inplace) elif new_format == BoundingBoxFormat.CXCYWH: - bounding_box = _xyxy_to_cxcywh(bounding_box, inplace) + bounding_boxes = _xyxy_to_cxcywh(bounding_boxes, inplace) - return bounding_box + return bounding_boxes -def convert_format_bounding_box( +def convert_format_bounding_boxes( inpt: datapoints._InputTypeJIT, old_format: Optional[BoundingBoxFormat] = None, new_format: Optional[BoundingBoxFormat] = None, inplace: bool = False, ) -> datapoints._InputTypeJIT: # This being a kernel / dispatcher hybrid, we need an option to pass `old_format` explicitly for simple tensor - # inputs as well as extract it from `datapoints.BoundingBox` inputs. However, putting a default value on + # inputs as well as extract it from `datapoints.BoundingBoxes` inputs. However, putting a default value on # `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the # default error that would be thrown if `new_format` had no default value. 
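# --- Illustrative sketch (editor's addition, not part of the patch): the two calling
# conventions handled by the kernel/dispatcher hybrid below. A plain tensor needs
# `old_format` spelled out, while a BoundingBoxes datapoint carries its format (and
# canvas_size) as metadata, so only `new_format` is passed. Assumes the datapoints
# constructor at this revision, i.e. BoundingBoxes(..., format=..., canvas_size=...).
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

xyxy = torch.tensor([[10.0, 20.0, 40.0, 60.0]])

# Plain tensor: the format cannot be inferred, so it must be given explicitly.
cxcywh = F.convert_format_bounding_boxes(
    xyxy,
    old_format=datapoints.BoundingBoxFormat.XYXY,
    new_format=datapoints.BoundingBoxFormat.CXCYWH,
)

# Datapoint: the format is read from the input; passing old_format here would raise.
boxes = datapoints.BoundingBoxes(xyxy, format="XYXY", canvas_size=(100, 100))
boxes_cxcywh = F.convert_format_bounding_boxes(boxes, new_format=datapoints.BoundingBoxFormat.CXCYWH)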
if new_format is None: - raise TypeError("convert_format_bounding_box() missing 1 required argument: 'new_format'") + raise TypeError("convert_format_bounding_boxes() missing 1 required argument: 'new_format'") if not torch.jit.is_scripting(): - _log_api_usage_once(convert_format_bounding_box) + _log_api_usage_once(convert_format_bounding_boxes) if torch.jit.is_scripting() or is_simple_tensor(inpt): if old_format is None: raise ValueError("For simple tensor inputs, `old_format` has to be passed.") - return _convert_format_bounding_box(inpt, old_format=old_format, new_format=new_format, inplace=inplace) - elif isinstance(inpt, datapoints.BoundingBox): + return _convert_format_bounding_boxes(inpt, old_format=old_format, new_format=new_format, inplace=inplace) + elif isinstance(inpt, datapoints.BoundingBoxes): if old_format is not None: raise ValueError("For bounding box datapoint inputs, `old_format` must not be passed.") - output = _convert_format_bounding_box( + output = _convert_format_bounding_boxes( inpt.as_subclass(torch.Tensor), old_format=inpt.format, new_format=new_format, inplace=inplace ) - return datapoints.BoundingBox.wrap_like(inpt, output, format=new_format) + return datapoints.BoundingBoxes.wrap_like(inpt, output, format=new_format) else: raise TypeError( f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead." ) -def _clamp_bounding_box( - bounding_box: torch.Tensor, format: BoundingBoxFormat, spatial_size: Tuple[int, int] +def _clamp_bounding_boxes( + bounding_boxes: torch.Tensor, format: BoundingBoxFormat, canvas_size: Tuple[int, int] ) -> torch.Tensor: # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every # BoundingBoxFormat instead of converting back and forth - in_dtype = bounding_box.dtype - bounding_box = bounding_box.clone() if bounding_box.is_floating_point() else bounding_box.float() - xyxy_boxes = convert_format_bounding_box( - bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True + in_dtype = bounding_boxes.dtype + bounding_boxes = bounding_boxes.clone() if bounding_boxes.is_floating_point() else bounding_boxes.float() + xyxy_boxes = convert_format_bounding_boxes( + bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True ) - xyxy_boxes[..., 0::2].clamp_(min=0, max=spatial_size[1]) - xyxy_boxes[..., 1::2].clamp_(min=0, max=spatial_size[0]) - out_boxes = convert_format_bounding_box( + xyxy_boxes[..., 0::2].clamp_(min=0, max=canvas_size[1]) + xyxy_boxes[..., 1::2].clamp_(min=0, max=canvas_size[0]) + out_boxes = convert_format_bounding_boxes( xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True ) return out_boxes.to(in_dtype) -def clamp_bounding_box( +def clamp_bounding_boxes( inpt: datapoints._InputTypeJIT, format: Optional[BoundingBoxFormat] = None, - spatial_size: Optional[Tuple[int, int]] = None, + canvas_size: Optional[Tuple[int, int]] = None, ) -> datapoints._InputTypeJIT: if not torch.jit.is_scripting(): - _log_api_usage_once(clamp_bounding_box) + _log_api_usage_once(clamp_bounding_boxes) if torch.jit.is_scripting() or is_simple_tensor(inpt): - if format is None or spatial_size is None: - raise ValueError("For simple tensor inputs, `format` and `spatial_size` has to be passed.") - return _clamp_bounding_box(inpt, format=format, spatial_size=spatial_size) - elif isinstance(inpt, datapoints.BoundingBox): - if format is not None or spatial_size is not None: - raise 
ValueError("For bounding box datapoint inputs, `format` and `spatial_size` must not be passed.") - output = _clamp_bounding_box(inpt.as_subclass(torch.Tensor), format=inpt.format, spatial_size=inpt.spatial_size) - return datapoints.BoundingBox.wrap_like(inpt, output) - else: - raise TypeError( - f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead." - ) - -def _num_value_bits(dtype: torch.dtype) -> int: - if dtype == torch.uint8: - return 8 - elif dtype == torch.int8: - return 7 - elif dtype == torch.int16: - return 15 - elif dtype == torch.int32: - return 31 - elif dtype == torch.int64: - return 63 - else: - raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.") - - -def convert_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor: - if image.dtype == dtype: - return image - - float_input = image.is_floating_point() - if torch.jit.is_scripting(): - # TODO: remove this branch as soon as `dtype.is_floating_point` is supported by JIT - float_output = torch.tensor(0, dtype=dtype).is_floating_point() - else: - float_output = dtype.is_floating_point - - if float_input: - # float to float - if float_output: - return image.to(dtype) - - # float to int - if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or ( - image.dtype == torch.float64 and dtype == torch.int64 - ): - raise RuntimeError(f"The conversion from {image.dtype} to {dtype} cannot be performed safely.") - - # For data in the range `[0.0, 1.0]`, just multiplying by the maximum value of the integer range and converting - # to the integer dtype is not sufficient. For example, `torch.rand(...).mul(255).to(torch.uint8)` will only - # be `255` if the input is exactly `1.0`. See https://github.com/pytorch/vision/pull/2078#issuecomment-612045321 - # for a detailed analysis. - # To mitigate this, we could round before we convert to the integer dtype, but this is an extra operation. - # Instead, we can also multiply by the maximum value plus something close to `1`. See - # https://github.com/pytorch/vision/pull/2078#issuecomment-613524965 for details. - eps = 1e-3 - max_value = float(_max_value(dtype)) - # We need to scale first since the conversion would otherwise turn the input range `[0.0, 1.0]` into the - # discrete set `{0, 1}`. - return image.mul(max_value + 1.0 - eps).to(dtype) - else: - # int to float - if float_output: - return image.to(dtype).mul_(1.0 / _max_value(image.dtype)) - - # int to int - num_value_bits_input = _num_value_bits(image.dtype) - num_value_bits_output = _num_value_bits(dtype) - - if num_value_bits_input > num_value_bits_output: - return image.bitwise_right_shift(num_value_bits_input - num_value_bits_output).to(dtype) - else: - return image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input) - - -# We changed the name to align it with the new naming scheme. Still, `convert_image_dtype` is -# prevalent and well understood. Thus, we just alias it without deprecating the old name. 
-convert_image_dtype = convert_dtype_image_tensor - - -def convert_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor: - return convert_dtype_image_tensor(video, dtype) - - -def convert_dtype( - inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], dtype: torch.dtype = torch.float -) -> torch.Tensor: - if not torch.jit.is_scripting(): - _log_api_usage_once(convert_dtype) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): - return convert_dtype_image_tensor(inpt, dtype) - elif isinstance(inpt, datapoints.Image): - output = convert_dtype_image_tensor(inpt.as_subclass(torch.Tensor), dtype) - return datapoints.Image.wrap_like(inpt, output) - elif isinstance(inpt, datapoints.Video): - output = convert_dtype_video(inpt.as_subclass(torch.Tensor), dtype) - return datapoints.Video.wrap_like(inpt, output) + if format is None or canvas_size is None: + raise ValueError("For simple tensor inputs, `format` and `canvas_size` has to be passed.") + return _clamp_bounding_boxes(inpt, format=format, canvas_size=canvas_size) + elif isinstance(inpt, datapoints.BoundingBoxes): + if format is not None or canvas_size is not None: + raise ValueError("For bounding box datapoint inputs, `format` and `canvas_size` must not be passed.") + output = _clamp_bounding_boxes(inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size) + return datapoints.BoundingBoxes.wrap_like(inpt, output) else: raise TypeError( - f"Input can either be a plain tensor or an `Image` or `Video` datapoint, " f"but got {type(inpt)} instead." + f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead." ) diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index 9abb3ac22ce..cda85ba906e 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -6,6 +6,7 @@ from torch.nn.functional import conv2d, pad as torch_pad from torchvision import datapoints +from torchvision.transforms._functional_tensor import _max_value from torchvision.transforms.functional import pil_to_tensor, to_pil_image from torchvision.utils import _log_api_usage_once @@ -182,3 +183,97 @@ def gaussian_blur( f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"but got {type(inpt)} instead." 
) + + +def _num_value_bits(dtype: torch.dtype) -> int: + if dtype == torch.uint8: + return 8 + elif dtype == torch.int8: + return 7 + elif dtype == torch.int16: + return 15 + elif dtype == torch.int32: + return 31 + elif dtype == torch.int64: + return 63 + else: + raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.") + + +def to_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor: + + if image.dtype == dtype: + return image + elif not scale: + return image.to(dtype) + + float_input = image.is_floating_point() + if torch.jit.is_scripting(): + # TODO: remove this branch as soon as `dtype.is_floating_point` is supported by JIT + float_output = torch.tensor(0, dtype=dtype).is_floating_point() + else: + float_output = dtype.is_floating_point + + if float_input: + # float to float + if float_output: + return image.to(dtype) + + # float to int + if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or ( + image.dtype == torch.float64 and dtype == torch.int64 + ): + raise RuntimeError(f"The conversion from {image.dtype} to {dtype} cannot be performed safely.") + + # For data in the range `[0.0, 1.0]`, just multiplying by the maximum value of the integer range and converting + # to the integer dtype is not sufficient. For example, `torch.rand(...).mul(255).to(torch.uint8)` will only + # be `255` if the input is exactly `1.0`. See https://github.com/pytorch/vision/pull/2078#issuecomment-612045321 + # for a detailed analysis. + # To mitigate this, we could round before we convert to the integer dtype, but this is an extra operation. + # Instead, we can also multiply by the maximum value plus something close to `1`. See + # https://github.com/pytorch/vision/pull/2078#issuecomment-613524965 for details. + eps = 1e-3 + max_value = float(_max_value(dtype)) + # We need to scale first since the conversion would otherwise turn the input range `[0.0, 1.0]` into the + # discrete set `{0, 1}`. 
+ return image.mul(max_value + 1.0 - eps).to(dtype) + else: + # int to float + if float_output: + return image.to(dtype).mul_(1.0 / _max_value(image.dtype)) + + # int to int + num_value_bits_input = _num_value_bits(image.dtype) + num_value_bits_output = _num_value_bits(dtype) + + if num_value_bits_input > num_value_bits_output: + return image.bitwise_right_shift(num_value_bits_input - num_value_bits_output).to(dtype) + else: + return image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input) + + +# We encourage users to use to_dtype() instead but we keep this for BC +def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor: + return to_dtype_image_tensor(image, dtype=dtype, scale=True) + + +def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor: + return to_dtype_image_tensor(video, dtype, scale=scale) + + +def to_dtype(inpt: datapoints._InputTypeJIT, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor: + if not torch.jit.is_scripting(): + _log_api_usage_once(to_dtype) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return to_dtype_image_tensor(inpt, dtype, scale=scale) + elif isinstance(inpt, datapoints.Image): + output = to_dtype_image_tensor(inpt.as_subclass(torch.Tensor), dtype, scale=scale) + return datapoints.Image.wrap_like(inpt, output) + elif isinstance(inpt, datapoints.Video): + output = to_dtype_video(inpt.as_subclass(torch.Tensor), dtype, scale=scale) + return datapoints.Video.wrap_like(inpt, output) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + return inpt.to(dtype) + else: + raise TypeError(f"Input can either be a plain tensor or a datapoint, but got {type(inpt)} instead.") diff --git a/torchvision/transforms/v2/utils.py b/torchvision/transforms/v2/utils.py index c4cf481bcd2..dd9f4489dee 100644 --- a/torchvision/transforms/v2/utils.py +++ b/torchvision/transforms/v2/utils.py @@ -6,15 +6,15 @@ from torchvision import datapoints from torchvision._utils import sequence_to_str -from torchvision.transforms.v2.functional import get_dimensions, get_spatial_size, is_simple_tensor +from torchvision.transforms.v2.functional import get_dimensions, get_size, is_simple_tensor -def query_bounding_box(flat_inputs: List[Any]) -> datapoints.BoundingBox: - bounding_boxes = [inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBox)] +def query_bounding_boxes(flat_inputs: List[Any]) -> datapoints.BoundingBoxes: + bounding_boxes = [inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBoxes)] if not bounding_boxes: - raise TypeError("No bounding box was found in the sample") + raise TypeError("No bounding boxes were found in the sample") elif len(bounding_boxes) > 1: - raise ValueError("Found multiple bounding boxes in the sample") + raise ValueError("Found multiple bounding boxes instances in the sample") return bounding_boxes.pop() @@ -22,7 +22,7 @@ def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]: chws = { tuple(get_dimensions(inpt)) for inpt in flat_inputs - if isinstance(inpt, (datapoints.Image, PIL.Image.Image, datapoints.Video)) or is_simple_tensor(inpt) + if check_type(inpt, (is_simple_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video)) } if not chws: raise TypeError("No image or video was found in the sample") @@ -32,14 +32,21 @@ def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]: return c, h, w -def query_spatial_size(flat_inputs: List[Any]) -> Tuple[int, int]: 
+def query_size(flat_inputs: List[Any]) -> Tuple[int, int]: sizes = { - tuple(get_spatial_size(inpt)) + tuple(get_size(inpt)) for inpt in flat_inputs - if isinstance( - inpt, (datapoints.Image, PIL.Image.Image, datapoints.Video, datapoints.Mask, datapoints.BoundingBox) + if check_type( + inpt, + ( + is_simple_tensor, + datapoints.Image, + PIL.Image.Image, + datapoints.Video, + datapoints.Mask, + datapoints.BoundingBoxes, + ), ) - or is_simple_tensor(inpt) } if not sizes: raise TypeError("No image, video, mask or bounding box was found in the sample")
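# --- Illustrative sketch (editor's addition, not part of the patch): the renamed size
# helpers in action. get_spatial_size()/query_spatial_size() become get_size()/
# query_size(), and BoundingBoxes contributes its canvas_size there. query_size()
# collects the size of every size-bearing input in the flattened sample and requires
# them all to agree. Assumes the datapoints constructors at this revision.
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2.utils import query_size

image = datapoints.Image(torch.rand(3, 100, 150))
boxes = datapoints.BoundingBoxes(
    torch.tensor([[10.0, 10.0, 50.0, 50.0]]), format="XYXY", canvas_size=(100, 150)
)

print(F.get_size(image))           # [100, 150]
print(F.get_size(boxes))           # [100, 150], read from boxes.canvas_size
print(query_size([image, boxes]))  # (100, 150)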