From b4af363c5f26014e1648dd2a0c4870ed03736fd3 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Thu, 24 Aug 2023 15:52:56 +0100
Subject: [PATCH] Rework the transforms v2 end-to-end gallery example to cover detection and segmentation

---
 .../v2_transforms/plot_transforms_v2_e2e.py | 74 ++++++++++++-------
 torchvision/transforms/v2/_misc.py          |  1 -
 2 files changed, 46 insertions(+), 29 deletions(-)

diff --git a/gallery/v2_transforms/plot_transforms_v2_e2e.py b/gallery/v2_transforms/plot_transforms_v2_e2e.py
index 6b556ae0e75..cb9843ce42b 100644
--- a/gallery/v2_transforms/plot_transforms_v2_e2e.py
+++ b/gallery/v2_transforms/plot_transforms_v2_e2e.py
@@ -1,16 +1,20 @@
 """
-==================================================
-Transforms v2: End-to-end object detection example
-==================================================
+================================================================
+Transforms v2: End-to-end object detection/segmentation example
+================================================================

 .. note::
     Try on `collab `_
     or :ref:`go to the end ` to download the full example code.

-Object detection is not supported out of the box by ``torchvision.transforms`` v1, since it only supports images.
-``torchvision.transforms.v2`` enables jointly transforming images, videos, bounding boxes, and masks. This example
-showcases an end-to-end object detection training using the stable ``torchvision.datasets`` and ``torchvision.models``
-as well as the new ``torchvision.transforms.v2`` v2 API.
+Object detection and segmentation tasks are natively supported:
+``torchvision.transforms.v2`` enables jointly transforming images, videos,
+bounding boxes, and masks.
+
+This example showcases an end-to-end instance segmentation training recipe using
+Torchvision utilities from ``torchvision.datasets``, ``torchvision.models`` and
+``torchvision.transforms.v2``. Everything covered here can be applied similarly
+to object detection or semantic segmentation tasks.
 """

 # %%
 import torch
 import torch.utils.data

-import PIL.Image
 from torchvision import models, datasets, datapoints
 from torchvision.transforms import v2

-torch.manual_seed(1)
+torch.manual_seed(0)

 # This loads fake data for illustration purposes of this example. In practice, you'll have
 # to replace this with the proper data.
@@ -36,6 +39,9 @@

 # %%
+# Dataset preparation
+# -------------------
+#
 # We start off by loading the :class:`~torchvision.datasets.CocoDetection` dataset to have a look at what it currently
 # returns.
@@ -84,6 +90,9 @@

 # %%
+# Transforms
+# ----------
+#
 # Let's now define our pre-processing transforms. All the transforms know how
 # to handle images, bouding boxes and masks when relevant.
 #
@@ -94,13 +103,12 @@
 transforms = v2.Compose(
     [
         v2.ToImage(),
-        v2.RandomPhotometricDistort(),
+        v2.RandomPhotometricDistort(p=1),
         v2.RandomZoomOut(fill={datapoints.Image: (123, 117, 104), "others": 0}),
         v2.RandomIoUCrop(),
-        v2.RandomHorizontalFlip(),
+        v2.RandomHorizontalFlip(p=1),
         v2.SanitizeBoundingBoxes(),
         v2.ToDtype(torch.float32, scale=True),
-        v2.ToPureTensor(),
     ]
 )
@@ -114,38 +122,48 @@
 #   :class:`~torchvision.transforms.v2.Image` object. This isn't strictly
 #   necessary, but relying on Tensors (here: a Tensor subclass) will
 #   :ref:`generally be faster `.
-# - Although the :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes` transform is a no-op in this example, but it
-#   should be placed at least once at the end of a detection pipeline to remove
-#   degenerate bounding boxes as well as the corresponding labels and optionally
-#   masks. It is particularly critical to add it if
-#   :class:`~torchvision.transforms.v2.RandomIoUCrop` was used.
+# - We are calling :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes` to
+#   make sure we remove degenerate bounding boxes, as well as their
+#   corresponding labels and masks.
+#   :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes` should be placed
+#   at least once at the end of a detection pipeline; it is particularly
+#   critical if :class:`~torchvision.transforms.v2.RandomIoUCrop` was used.
 #
 # Let's look how the sample looks like with our augmentation pipeline in place:
-
 # sphinx_gallery_thumbnail_number = 2
 plot([dataset[0], dataset[1]])

 # %%
-# We can see that the color of the image was distorted, we zoomed out on it (off center) and flipped it horizontally.
-# In all of this, the bounding box was transformed accordingly. And without any further ado, we can start training.
+# We can see that the colors of the images were distorted, and that the images were zoomed in or out and flipped.
+# The bounding boxes and the masks were transformed accordingly. And without any further ado, we can start training.
+#
+# Data loading and training loop
+# ------------------------------
+#
+# Below, we're using Mask R-CNN, which is an instance segmentation model, but
+# everything we've covered in this tutorial also applies to object detection and
+# semantic segmentation tasks.

 data_loader = torch.utils.data.DataLoader(
     dataset,
     batch_size=2,
-    # We need a custom collation function here, since the object detection models expect a
-    # sequence of images and target dictionaries. The default collation function tries to
-    # `torch.stack` the individual elements, which fails in general for object detection,
-    # because the number of object instances varies between the samples. This is the same for
-    # `torchvision.transforms` v1
+    # We need a custom collation function here, since the object detection
+    # models expect a sequence of images and target dictionaries. The default
+    # collation function tries to :func:`~torch.stack` the individual elements,
+    # which fails in general for object detection, because the number of bounding
+    # boxes varies between the images of the same batch.
     collate_fn=lambda batch: tuple(zip(*batch)),
 )

-model = models.get_model("ssd300_vgg16", weights=None, weights_backbone=None).train()
+model = models.get_model("maskrcnn_resnet50_fpn_v2", weights=None, weights_backbone=None).train()

 for imgs, targets in data_loader:
     loss_dict = model(imgs, targets)
-    print(loss_dict)  # Put your training logic here
-    break
+
+    print(f"{[img.shape for img in imgs] = }")
+    print(f"{[type(target) for target in targets] = }")
+    for name, loss_val in loss_dict.items():
+        print(f"{name:<20}{loss_val:.3f}")
diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py
index 171d371457b..79255b2256f 100644
--- a/torchvision/transforms/v2/_misc.py
+++ b/torchvision/transforms/v2/_misc.py
@@ -395,7 +395,6 @@ def forward(self, *inputs: Any) -> Any:
                 new_format=datapoints.BoundingBoxFormat.XYXY,
             ),
         )
-        print(f"{type(boxes) =}")
         ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
         valid = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(dim=-1)
         # TODO: Do we really need to check for out of bounds here? All
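
Reviewer note (not part of the patch): the collation comment added in the gallery example can be illustrated with a small, self-contained sketch using plain PyTorch only. ``ToyDetectionDataset`` and ``detection_collate`` are hypothetical names introduced here for illustration; the point is that the number of boxes differs between samples, so the default collation, which tries to stack the targets, cannot work, and the example therefore passes ``collate_fn=lambda batch: tuple(zip(*batch))``.

import torch
from torch.utils.data import DataLoader, Dataset


class ToyDetectionDataset(Dataset):
    # Hypothetical toy dataset: each sample carries a different number of boxes.
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        img = torch.rand(3, 64, 64)
        num_boxes = idx + 1  # 1, 2, 3, 4 boxes -> target shapes differ across samples
        target = {
            "boxes": torch.rand(num_boxes, 4),
            "labels": torch.randint(0, 10, (num_boxes,)),
        }
        return img, target


def detection_collate(batch):
    # Turn [(img0, tgt0), (img1, tgt1)] into ((img0, img1), (tgt0, tgt1)) instead of
    # letting the default collate try to torch.stack() targets of unequal shapes.
    return tuple(zip(*batch))


loader = DataLoader(ToyDetectionDataset(), batch_size=2, collate_fn=detection_collate)
imgs, targets = next(iter(loader))
print([t["boxes"].shape for t in targets])  # [torch.Size([1, 4]), torch.Size([2, 4])]

With the default collate function, the same DataLoader would raise a RuntimeError when trying to stack the differently shaped ``boxes`` tensors.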