Rewrite transforms v2 e2e example (#7881)
NicolasHug authored Aug 25, 2023
1 parent 92e4e9c commit 224cbc8
Showing 3 changed files with 142 additions and 103 deletions.
29 changes: 21 additions & 8 deletions gallery/v2_transforms/helpers.py
@@ -1,5 +1,8 @@
import matplotlib.pyplot as plt
from torchvision.utils import draw_bounding_boxes
import torch
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F


def plot(imgs):
@@ -12,20 +15,30 @@ def plot(imgs):
_, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
for row_idx, row in enumerate(imgs):
for col_idx, img in enumerate(row):
bboxes = None
boxes = None
masks = None
if isinstance(img, tuple):
bboxes = img[1]
img = img[0]
if isinstance(bboxes, dict):
bboxes = bboxes['bboxes']
img, target = img
if isinstance(target, dict):
boxes = target.get("boxes")
masks = target.get("masks")
elif isinstance(target, datapoints.BoundingBoxes):
boxes = target
else:
raise ValueError(f"Unexpected target type: {type(target)}")
img = F.to_image(img)
if img.dtype.is_floating_point and img.min() < 0:
# Poor man's re-normalization for the colors to be OK-ish. This
# is useful for images coming out of Normalize()
img -= img.min()
img /= img.max()

if bboxes is not None:
img = draw_bounding_boxes(img, bboxes, colors="yellow", width=3)
img = F.to_dtype(img, torch.uint8, scale=True)
if boxes is not None:
img = draw_bounding_boxes(img, boxes, colors="yellow", width=3)
if masks is not None:
img = draw_segmentation_masks(img, masks.to(torch.bool), colors=["green"] * masks.shape[0], alpha=.65)

ax = axs[row_idx, col_idx]
ax.imshow(img.permute(1, 2, 0).numpy())
ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
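For reference, here is a minimal sketch of how the reworked ``plot`` helper can be used after this change. It assumes the snippet is run from ``gallery/v2_transforms/`` (so that ``helpers`` is importable and ``../assets/astronaut.jpg`` resolves) and the ``canvas_size`` naming of the datapoints API at this commit; the box and mask values are made up for illustration.

import torch
from torchvision import datapoints
from torchvision.io import read_image
from helpers import plot

img = read_image("../assets/astronaut.jpg")  # uint8 tensor of shape (3, H, W)

# A single made-up box and a matching made-up mask.
boxes = datapoints.BoundingBoxes(
    [[50, 50, 200, 200]],
    format="XYXY",
    canvas_size=img.shape[-2:],
)
mask = torch.zeros((1, *img.shape[-2:]), dtype=torch.bool)
mask[0, 50:200, 50:200] = True
masks = datapoints.Mask(mask)

# One panel per entry: an (image, BoundingBoxes) pair, then an (image, target-dict) pair.
plot([(img, boxes), (img, {"boxes": boxes, "masks": masks})])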
18 changes: 11 additions & 7 deletions gallery/v2_transforms/plot_transforms_v2.py
@@ -90,7 +90,7 @@

from torchvision import datapoints # we'll describe this a bit later, bear with us

bboxes = datapoints.BoundingBoxes(
boxes = datapoints.BoundingBoxes(
[
[15, 10, 370, 510],
[275, 340, 510, 510],
@@ -103,9 +103,10 @@
v2.RandomPhotometricDistort(p=1),
v2.RandomHorizontalFlip(p=1),
])
out_img, out_bboxes = transforms(img, bboxes)
out_img, out_boxes = transforms(img, boxes)
print(type(boxes), type(out_boxes))

plot([(img, bboxes), (out_img, out_bboxes)])
plot([(img, boxes), (out_img, out_boxes)])

# %%
#
@@ -119,6 +120,9 @@
# answer these in the next sections.

# %%
#
# .. _what_are_datapoints:
#
# What are Datapoints?
# --------------------
#
@@ -151,7 +155,7 @@
#
# Above, we've seen two examples: one where we passed a single image as input
# i.e. ``out = transforms(img)``, and one where we passed both an image and
# bounding boxes, i.e. ``out_img, out_bboxes = transforms(img, bboxes)``.
# bounding boxes, i.e. ``out_img, out_boxes = transforms(img, boxes)``.
#
# In fact, transforms support **arbitrary input structures**. The input can be a
# single image, a tuple, an arbitrarily nested dictionary... pretty much
@@ -160,15 +164,15 @@
# we're getting the same structure as output:

target = {
"bboxes": bboxes,
"labels": torch.arange(bboxes.shape[0]),
"boxes": boxes,
"labels": torch.arange(boxes.shape[0]),
"this_is_ignored": ("arbitrary", {"structure": "!"})
}

# Re-using the transforms and definitions from above.
out_img, out_target = transforms(img, target)

plot([(img, target["bboxes"]), (out_img, out_target["bboxes"])])
plot([(img, target["boxes"]), (out_img, out_target["boxes"])])
print(f"{out_target['this_is_ignored']}")

# %%
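To make the structure-preservation point above concrete, here is a small self-contained sketch (names and values are illustrative, assuming the datapoints API used in this commit) that passes a nested dictionary through a v2 pipeline and gets the same nesting and types back:

import torch
from torchvision import datapoints
from torchvision.transforms import v2

img = datapoints.Image(torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8))
boxes = datapoints.BoundingBoxes(
    [[10, 10, 100, 100]], format="XYXY", canvas_size=(256, 256)
)

structured_input = {
    "image": img,
    "annotations": {"boxes": boxes, "labels": torch.tensor([0])},
    "metadata": ("arbitrary", {"structure": "!"}),  # plain Python objects pass through
}

pipeline = v2.Compose([v2.RandomHorizontalFlip(p=1), v2.Resize((128, 128), antialias=True)])
out = pipeline(structured_input)

print(type(out["image"]), tuple(out["image"].shape))  # Image datapoint, (3, 128, 128)
print(type(out["annotations"]["boxes"]))              # BoundingBoxes, flipped and resized
print(out["metadata"])                                # untouched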
198 changes: 110 additions & 88 deletions gallery/v2_transforms/plot_transforms_v2_e2e.py
@@ -1,146 +1,168 @@
"""
==================================================
Transforms v2: End-to-end object detection example
==================================================
===============================================================
Transforms v2: End-to-end object detection/segmentation example
===============================================================
.. note::
Try on `Colab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_transforms_v2_e2e.ipynb>`_
or :ref:`go to the end <sphx_glr_download_auto_examples_v2_transforms_plot_transforms_v2_e2e.py>` to download the full example code.
Object detection is not supported out of the box by ``torchvision.transforms`` v1, since it only supports images.
``torchvision.transforms.v2`` enables jointly transforming images, videos, bounding boxes, and masks. This example
showcases an end-to-end object detection training using the stable ``torchvision.datasets`` and ``torchvision.models``
as well as the new ``torchvision.transforms.v2`` API.
Object detection and segmentation tasks are natively supported:
``torchvision.transforms.v2`` enables jointly transforming images, videos,
bounding boxes, and masks.
This example showcases an end-to-end instance segmentation training case using
Torchvision utils from ``torchvision.datasets``, ``torchvision.models`` and
``torchvision.transforms.v2``. Everything covered here can be applied similarly
to object detection or semantic segmentation tasks.
"""

# %%
import pathlib

import PIL.Image

import torch
import torch.utils.data

from torchvision import models, datasets
import torchvision.transforms.v2 as transforms


def show(sample):
import matplotlib.pyplot as plt

from torchvision.transforms.v2 import functional as F
from torchvision.utils import draw_bounding_boxes

image, target = sample
if isinstance(image, PIL.Image.Image):
image = F.to_image(image)
image = F.to_dtype(image, torch.uint8, scale=True)
annotated_image = draw_bounding_boxes(image, target["boxes"], colors="yellow", width=3)
from torchvision import models, datasets, datapoints
from torchvision.transforms import v2

fig, ax = plt.subplots()
ax.imshow(annotated_image.permute(1, 2, 0).numpy())
ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
fig.tight_layout()
torch.manual_seed(0)

fig.show()
# This loads fake data for illustration purposes of this example. In practice, you'll have
# to replace this with the proper data.
# If you're trying to run this example on Colab, you can download the assets and the
# helpers from https://github.com/pytorch/vision/tree/main/gallery/
ROOT = pathlib.Path("../assets") / "coco"
IMAGES_PATH = str(ROOT / "images")
ANNOTATIONS_PATH = str(ROOT / "instances.json")
from helpers import plot


# %%
# Dataset preparation
# -------------------
#
# We start off by loading the :class:`~torchvision.datasets.CocoDetection` dataset to have a look at what it currently
# returns, and we'll see how to convert it to a format that is compatible with our new transforms.

def load_example_coco_detection_dataset(**kwargs):
# This loads fake data for illustration purposes of this example. In practice, you'll have
# to replace this with the proper data
root = pathlib.Path("../assets") / "coco"
return datasets.CocoDetection(str(root / "images"), str(root / "instances.json"), **kwargs)

# returns.

dataset = load_example_coco_detection_dataset()
dataset = datasets.CocoDetection(IMAGES_PATH, ANNOTATIONS_PATH)

sample = dataset[0]
image, target = sample
print(type(image))
print(type(target), type(target[0]), list(target[0].keys()))
img, target = sample
print(f"{type(img) = }\n{type(target) = }\n{type(target[0]) = }\n{target[0].keys() = }")


# %%
# The dataset returns a two-tuple with the first item being a :class:`PIL.Image.Image` and the second one a list of
# dictionaries, each containing the annotations for a single object instance. As is, this format is not compatible
# with ``torchvision.transforms.v2``, nor with the models. To overcome that, we provide the
# Torchvision datasets preserve the data structure and types as intended by
# the dataset authors. So by default, the output structure may not always be
# compatible with the models or the transforms.
#
# To overcome that, we can use the
# :func:`~torchvision.datasets.wrap_dataset_for_transforms_v2` function. For
# :class:`~torchvision.datasets.CocoDetection`, this changes the target structure to a single dictionary of lists. It
# also adds the key-value-pairs ``"boxes"``, ``"masks"``, and ``"labels"`` wrapped in the corresponding
# ``torchvision.datapoints``. By default, it only returns ``"boxes"`` and ``"labels"`` to avoid transforming unnecessary
# items down the line, but you can pass the ``target_type`` parameter for fine-grained control.
# :class:`~torchvision.datasets.CocoDetection`, this changes the target
# structure to a single dictionary of lists:

dataset = datasets.wrap_dataset_for_transforms_v2(dataset)
dataset = datasets.wrap_dataset_for_transforms_v2(dataset, target_keys=("boxes", "labels", "masks"))

sample = dataset[0]
image, target = sample
print(type(image))
print(type(target), list(target.keys()))
print(type(target["boxes"]), type(target["labels"]))
img, target = sample
print(f"{type(img) = }\n{type(target) = }\n{target.keys() = }")
print(f"{type(target['boxes']) = }\n{type(target['labels']) = }\n{type(target['masks']) = }")

# %%
# We used the ``target_keys`` parameter to specify the kind of output we're
# interested in. Our dataset now returns a target which is a dict where the
# values are :ref:`Datapoints <what_are_datapoints>` (all are
# :class:`torch.Tensor` subclasses). We've dropped all unnecessary keys from
# the previous output, but if you need any of the original keys, e.g.
# "image_id", you can still ask for it.
#
# .. note::
#
# If you just want to do detection, you don't need to and shouldn't pass
# "masks" in ``target_keys``: if masks are present in the sample, they will
# be transformed, slowing down your transformations unnecessarily.
#
# As a baseline, let's have a look at a sample without transformations:

show(sample)
plot([dataset[0], dataset[1]])

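# %%
# Following up on the note above, here is a small illustrative sketch (not used
# by the rest of this example) of a detection-only setup, where "masks" is
# simply left out of ``target_keys`` so that masks are never converted or
# transformed:

detection_only = datasets.CocoDetection(IMAGES_PATH, ANNOTATIONS_PATH)
detection_only = datasets.wrap_dataset_for_transforms_v2(
    detection_only, target_keys=("boxes", "labels")
)
_, detection_target = detection_only[0]
print(f"{detection_target.keys() = }")  # only the requested keys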

# %%
# With the dataset properly set up, we can now define the augmentation pipeline. This is done the same way it is done in
# ``torchvision.transforms`` v1, but now handles bounding boxes and masks without any extra configuration.
# Transforms
# ----------
#
# Let's now define our pre-processing transforms. All the transforms know how
# to handle images, bounding boxes and masks when relevant.
#
# Transforms are typically passed as the ``transforms`` parameter of the
# dataset so that they can leverage multi-processing from the
# :class:`torch.utils.data.DataLoader`.

transform = transforms.Compose(
transforms = v2.Compose(
[
transforms.RandomPhotometricDistort(),
transforms.RandomZoomOut(fill={PIL.Image.Image: (123, 117, 104), "others": 0}),
transforms.RandomIoUCrop(),
transforms.RandomHorizontalFlip(),
transforms.ToImage(),
transforms.ConvertImageDtype(torch.float32),
transforms.SanitizeBoundingBoxes(),
v2.ToImage(),
v2.RandomPhotometricDistort(p=1),
v2.RandomZoomOut(fill={datapoints.Image: (123, 117, 104), "others": 0}),
v2.RandomIoUCrop(),
v2.RandomHorizontalFlip(p=1),
v2.SanitizeBoundingBoxes(),
v2.ToDtype(torch.float32, scale=True),
]
)

dataset = datasets.CocoDetection(IMAGES_PATH, ANNOTATIONS_PATH, transforms=transforms)
dataset = datasets.wrap_dataset_for_transforms_v2(dataset, target_keys=["boxes", "labels", "masks"])

# %%
# .. note::
# Although the :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes` transform is a no-op in this example, it
# should be placed at least once at the end of a detection pipeline to remove degenerate bounding boxes as well as
# the corresponding labels and optionally masks. It is particularly critical to add it if
# :class:`~torchvision.transforms.v2.RandomIoUCrop` was used.
# A few things are worth noting here:
#
# - We're converting the PIL image into a
# :class:`~torchvision.datapoints.Image` object. This isn't strictly
# necessary, but relying on Tensors (here: a Tensor subclass) will
# :ref:`generally be faster <transforms_perf>`.
# - We are calling :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes` to
# make sure we remove degenerate bounding boxes, as well as their
# corresponding labels and masks.
# :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes` should be placed
# at least once at the end of a detection pipeline; it is particularly
# critical if :class:`~torchvision.transforms.v2.RandomIoUCrop` was used.
#
# Let's look at how the sample looks with our augmentation pipeline in place:

dataset = load_example_coco_detection_dataset(transforms=transform)
dataset = datasets.wrap_dataset_for_transforms_v2(dataset)

torch.manual_seed(3141)
sample = dataset[0]

# sphinx_gallery_thumbnail_number = 2
show(sample)
plot([dataset[0], dataset[1]])

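# %%
# As a small illustrative aside (separate from the pipeline above), this is the
# kind of cleanup :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes`
# performs: degenerate boxes are dropped together with their labels. The box
# values below are made up.

degenerate_boxes = datapoints.BoundingBoxes(
    [[10, 10, 200, 200], [50, 50, 50, 50]],  # the second box has zero area
    format="XYXY",
    canvas_size=(300, 300),
)
sanitized = v2.SanitizeBoundingBoxes()({"boxes": degenerate_boxes, "labels": torch.tensor([1, 2])})
print(f"{sanitized['boxes'].shape = }")  # only the first box is left
print(f"{sanitized['labels'] = }")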

# %%
# We can see that the color of the image was distorted, we zoomed out on it (off center) and flipped it horizontally.
# In all of this, the bounding box was transformed accordingly. And without any further ado, we can start training.
# We can see that the colors of the images were distorted, and that the images were zoomed in or out and flipped.
# The bounding boxes and the masks were transformed accordingly. And without any further ado, we can start training.
#
# Data loading and training loop
# ------------------------------
#
# Below we're using Mask R-CNN, which is an instance segmentation model, but
# everything we've covered in this tutorial also applies to object detection and
# semantic segmentation tasks.

data_loader = torch.utils.data.DataLoader(
dataset,
batch_size=2,
# We need a custom collation function here, since the object detection models expect a
# sequence of images and target dictionaries. The default collation function tries to
# `torch.stack` the individual elements, which fails in general for object detection,
# because the number of object instances varies between the samples. This is the same for
# `torchvision.transforms` v1
# We need a custom collation function here, since the object detection
# models expect a sequence of images and target dictionaries. The default
# collation function tries to torch.stack() the individual elements,
# which fails in general for object detection, because the number of bounding
# boxes varies between the images of the same batch.
collate_fn=lambda batch: tuple(zip(*batch)),
)

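# %%
# To see what this collation does, here is a tiny sketch with dummy placeholder
# samples: a batch of (image, target) pairs is regrouped into a tuple of images
# and a tuple of targets, without any stacking.

dummy_batch = [("img_a", {"boxes": "boxes_a"}), ("img_b", {"boxes": "boxes_b"})]
collated_imgs, collated_targets = tuple(zip(*dummy_batch))
print(f"{collated_imgs = }")     # ('img_a', 'img_b')
print(f"{collated_targets = }")  # ({'boxes': 'boxes_a'}, {'boxes': 'boxes_b'})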
model = models.get_model("ssd300_vgg16", weights=None, weights_backbone=None).train()
model = models.get_model("maskrcnn_resnet50_fpn_v2", weights=None, weights_backbone=None).train()

for images, targets in data_loader:
loss_dict = model(images, targets)
print(loss_dict)
for imgs, targets in data_loader:
loss_dict = model(imgs, targets)
# Put your training logic here
break

print(f"{[img.shape for img in imgs] = }")
print(f"{[type(target) for target in targets] = }")
for name, loss_val in loss_dict.items():
print(f"{name:<20}{loss_val:.3f}")

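# %%
# A minimal sketch of what the "training logic" placeholder above could look
# like; the optimizer choice and hyper-parameters here are illustrative, not
# prescribed by this example.

optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9)

for imgs, targets in data_loader:
    loss_dict = model(imgs, targets)
    # Detection/segmentation models return a dict of losses; sum them for backprop.
    loss = sum(loss_dict.values())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    break  # a single step is enough for illustration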