ajenaejgng
NicolasHug committed Aug 24, 2023
1 parent fef649b commit b4af363
Showing 2 changed files with 46 additions and 29 deletions.
74 changes: 46 additions & 28 deletions gallery/v2_transforms/plot_transforms_v2_e2e.py
@@ -1,29 +1,32 @@
"""

[GitHub Actions / bc annotations on line 1: "Function show: function deleted"; "Function load_example_coco_detection_dataset: function deleted"]
-==================================================
-Transforms v2: End-to-end object detection example
-==================================================
+===============================================================
+Transforms v2: End-to-end object detection/segmentation example
+===============================================================
.. note::
Try on `Colab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_transforms_v2_e2e.ipynb>`_
or :ref:`go to the end <sphx_glr_download_auto_examples_v2_transforms_plot_transforms_v2_e2e.py>` to download the full example code.
-Object detection is not supported out of the box by ``torchvision.transforms`` v1, since it only supports images.
-``torchvision.transforms.v2`` enables jointly transforming images, videos, bounding boxes, and masks. This example
-showcases an end-to-end object detection training using the stable ``torchvision.datasets`` and ``torchvision.models``
-as well as the new ``torchvision.transforms.v2`` API.
+Object detection and segmentation tasks are natively supported:
+``torchvision.transforms.v2`` enables jointly transforming images, videos,
+bounding boxes, and masks.
+This example showcases an end-to-end instance segmentation training case using
+Torchvision utils from ``torchvision.datasets``, ``torchvision.models`` and
+``torchvision.transforms.v2``. Everything covered here can be applied similarly
+to object detection or semantic segmentation tasks.
"""

# %%
import pathlib

import torch
import torch.utils.data
import PIL.Image

from torchvision import models, datasets, datapoints
from torchvision.transforms import v2

-torch.manual_seed(1)
+torch.manual_seed(0)

# This loads fake data for illustration purposes. In practice, you'll have
# to replace it with the proper data.
@@ -36,6 +39,9 @@


# %%
+# Dataset preparation
+# -------------------
+#
# We start off by loading the :class:`~torchvision.datasets.CocoDetection` dataset to have a look at what it currently
# returns.
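The loading cell itself is collapsed in this view. As a rough sketch of what it does (the paths point at the fake data generated above and are placeholders):

# Placeholder paths to the fake COCO-style data generated above.
ROOT = pathlib.Path("assets") / "coco"
IMAGES_PATH = str(ROOT / "images")
ANNOTATIONS_PATH = str(ROOT / "instances.json")

dataset = datasets.CocoDetection(IMAGES_PATH, ANNOTATIONS_PATH)
img, target = dataset[0]
print(type(img), type(target))  # a PIL image and a list of annotation dicts

# Wrap the dataset so it returns datapoints (boxes, masks, labels) that
# the v2 transforms can operate on.
dataset = datasets.wrap_dataset_for_transforms_v2(dataset)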

@@ -84,6 +90,9 @@


# %%
+# Transforms
+# ----------
+#
# Let's now define our pre-processing transforms. All the transforms know how
# to handle images, bounding boxes and masks when relevant.
#
@@ -94,13 +103,12 @@
transforms = v2.Compose(
[
v2.ToImage(),
-v2.RandomPhotometricDistort(),
+v2.RandomPhotometricDistort(p=1),
v2.RandomZoomOut(fill={datapoints.Image: (123, 117, 104), "others": 0}),
v2.RandomIoUCrop(),
-v2.RandomHorizontalFlip(),
+v2.RandomHorizontalFlip(p=1),
v2.SanitizeBoundingBoxes(),
v2.ToDtype(torch.float32, scale=True),
-v2.ToPureTensor(),
]
)

@@ -114,38 +122,48 @@
# :class:`~torchvision.transforms.v2.Image` object. This isn't strictly
# necessary, but relying on Tensors (here: a Tensor subclass) will
# :ref:`generally be faster <transforms_perf>`.
-# - Although the :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes` transform is a no-op in this example, but it
-# should be placed at least once at the end of a detection pipeline to remove
-# degenerate bounding boxes as well as the corresponding labels and optionally
-# masks. It is particularly critical to add it if
-# :class:`~torchvision.transforms.v2.RandomIoUCrop` was used.
+# - We are calling :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes` to
+# make sure we remove degenerate bounding boxes, as well as their
+# corresponding labels and masks.
+# :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes` should be placed
+# at least once at the end of a detection pipeline; it is particularly
+# critical if :class:`~torchvision.transforms.v2.RandomIoUCrop` was used. A short sketch of this behavior follows below.
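To make that sketch concrete, here is a minimal editor-added illustration of :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes` removing a degenerate box. It assumes the default ``labels_getter``, which looks up a ``"labels"`` key in a sample dict.

boxes = datapoints.BoundingBoxes(
    [[10, 10, 100, 100], [50, 50, 50, 50]],  # the second box has zero width and height
    format="XYXY",
    canvas_size=(480, 640),
)
sample = {"boxes": boxes, "labels": torch.tensor([1, 2])}
out = v2.SanitizeBoundingBoxes()(sample)
print(out["boxes"], out["labels"])  # the degenerate box and its label are removed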
#
# Let's look at how the samples appear with our augmentation pipeline in place:


# sphinx_gallery_thumbnail_number = 2
plot([dataset[0], dataset[1]])


# %%
-# We can see that the color of the image was distorted, we zoomed out on it (off center) and flipped it horizontally.
-# In all of this, the bounding box was transformed accordingly. And without any further ado, we can start training.
+# We can see that the colors of the images were distorted, and that they were zoomed in or out and flipped.
+# The bounding boxes and the masks were transformed accordingly. And without further ado, we can start training.
+#
+# Data loading and training loop
+# ------------------------------
+#
+# Below we're using Mask R-CNN, which is an instance segmentation model, but
+# everything we've covered in this tutorial also applies to object detection and
+# semantic segmentation tasks.

data_loader = torch.utils.data.DataLoader(
dataset,
batch_size=2,
-# We need a custom collation function here, since the object detection models expect a
-# sequence of images and target dictionaries. The default collation function tries to
-# `torch.stack` the individual elements, which fails in general for object detection,
-# because the number of object instances varies between the samples. This is the same for
-# `torchvision.transforms` v1
+# We need a custom collation function here, since the object detection
+# models expect a sequence of images and target dictionaries. The default
+# collation function tries to :func:`~torch.stack` the individual elements,
+# which fails in general for object detection, because the number of bounding
+# boxes varies between the images of the same batch.
collate_fn=lambda batch: tuple(zip(*batch)),
)
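As a tiny illustration of what that collation lambda does (the values here are stand-ins, not real samples):

batch = [("img_a", "target_a"), ("img_b", "target_b")]
imgs, targets = tuple(zip(*batch))
print(imgs)     # ('img_a', 'img_b')
print(targets)  # ('target_a', 'target_b')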

model = models.get_model("ssd300_vgg16", weights=None, weights_backbone=None).train()
model = models.get_model("maskrcnn_resnet50_fpn_v2", weights=None, weights_backbone=None).train()

for imgs, targets in data_loader:
loss_dict = model(imgs, targets)
-print(loss_dict)
# Put your training logic here
break

print(f"{[img.shape for img in imgs] = }")
print(f"{[type(target) for target in targets] = }")
for name, loss_val in loss_dict.items():
print(f"{name:<20}{loss_val:.3f}")
1 change: 0 additions & 1 deletion torchvision/transforms/v2/_misc.py
@@ -395,7 +395,6 @@ def forward(self, *inputs: Any) -> Any:
new_format=datapoints.BoundingBoxFormat.XYXY,
),
)
print(f"{type(boxes) =}")
ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
valid = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(dim=-1)
# TODO: Do we really need to check for out of bounds here? All
