-
Notifications
You must be signed in to change notification settings - Fork 6.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Rewrite transforms v2 e2e example #7881
Changes from 5 commits
0887542
9d11700
958d652
fef649b
b4af363
f404f6c
f9b90b2
e799302
f719a60
980556d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,8 @@ | ||
import matplotlib.pyplot as plt | ||
from torchvision.utils import draw_bounding_boxes | ||
import torch | ||
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks | ||
from torchvision import datapoints | ||
from torchvision.transforms.v2 import functional as F | ||
|
||
|
||
def plot(imgs): | ||
|
@@ -12,20 +15,30 @@ def plot(imgs): | |
_, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False) | ||
for row_idx, row in enumerate(imgs): | ||
for col_idx, img in enumerate(row): | ||
bboxes = None | ||
boxes = None | ||
masks = None | ||
if isinstance(img, tuple): | ||
bboxes = img[1] | ||
img = img[0] | ||
if isinstance(bboxes, dict): | ||
bboxes = bboxes['bboxes'] | ||
img, target = img | ||
if isinstance(target, dict): | ||
boxes = target.get("boxes") | ||
masks = target.get("masks") | ||
elif isinstance(target, datapoints.BoundingBoxes): | ||
boxes = target | ||
else: | ||
raise ValueError(f"Unexpected target type: {type(target)}") | ||
img = F.to_image(img) | ||
if img.dtype.is_floating_point and img.min() < 0: | ||
# Poor man's re-normalization for the colors to be OK-ish. This | ||
# is useful for images coming out of Normalize() | ||
img -= img.min() | ||
img /= img.max() | ||
|
||
if bboxes is not None: | ||
img = draw_bounding_boxes(img, bboxes, colors="yellow", width=3) | ||
img = F.to_dtype(img, torch.uint8, scale=True) | ||
if boxes is not None: | ||
img = draw_bounding_boxes(img, boxes, colors="yellow", width=3) | ||
if masks is not None: | ||
img = draw_segmentation_masks(img, masks.to(torch.bool), alpha=.65) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we maybe also make the masks yellow or some other color? I fell in the current version the masks is not completely obvious on the image. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hmmm looks like the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe I was looking at the wrong output - it actually works fine and the masks are green (and very obvious) |
||
|
||
ax = axs[row_idx, col_idx] | ||
ax.imshow(img.permute(1, 2, 0).numpy()) | ||
ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -90,7 +90,7 @@ | |||||||||||||||||||
|
||||||||||||||||||||
from torchvision import datapoints # we'll describe this a bit later, bare with us | ||||||||||||||||||||
|
||||||||||||||||||||
bboxes = datapoints.BoundingBoxes( | ||||||||||||||||||||
boxes = datapoints.BoundingBoxes( | ||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Dataset wrapper uses "boxes" instead of "bboxes". I don't have a pref, but we should align (hence the changes in this example) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The reason the dataset wrapper uses vision/torchvision/models/detection/ssd.py Lines 333 to 341 in f7c7bdf
|
||||||||||||||||||||
[ | ||||||||||||||||||||
[15, 10, 370, 510], | ||||||||||||||||||||
[275, 340, 510, 510], | ||||||||||||||||||||
|
@@ -103,9 +103,10 @@ | |||||||||||||||||||
v2.RandomPhotometricDistort(p=1), | ||||||||||||||||||||
v2.RandomHorizontalFlip(p=1), | ||||||||||||||||||||
]) | ||||||||||||||||||||
out_img, out_bboxes = transforms(img, bboxes) | ||||||||||||||||||||
out_img, out_boxes = transforms(img, boxes) | ||||||||||||||||||||
print(type(boxes), type(out_boxes)) | ||||||||||||||||||||
|
||||||||||||||||||||
plot([(img, bboxes), (out_img, out_bboxes)]) | ||||||||||||||||||||
plot([(img, boxes), (out_img, out_boxes)]) | ||||||||||||||||||||
|
||||||||||||||||||||
# %% | ||||||||||||||||||||
# | ||||||||||||||||||||
|
@@ -119,6 +120,9 @@ | |||||||||||||||||||
# answer these in the next sections. | ||||||||||||||||||||
|
||||||||||||||||||||
# %% | ||||||||||||||||||||
# | ||||||||||||||||||||
# .. _what_are_datapoints: | ||||||||||||||||||||
# | ||||||||||||||||||||
# What are Datapoints? | ||||||||||||||||||||
# -------------------- | ||||||||||||||||||||
# | ||||||||||||||||||||
|
@@ -151,7 +155,7 @@ | |||||||||||||||||||
# | ||||||||||||||||||||
# Above, we've seen two examples: one where we passed a single image as input | ||||||||||||||||||||
# i.e. ``out = transforms(img)``, and one where we passed both an image and | ||||||||||||||||||||
# bounding boxes, i.e. ``out_img, out_bboxes = transforms(img, bboxes)``. | ||||||||||||||||||||
# bounding boxes, i.e. ``out_img, out_boxes = transforms(img, boxes)``. | ||||||||||||||||||||
# | ||||||||||||||||||||
# In fact, transforms support **arbitrary input structures**. The input can be a | ||||||||||||||||||||
# single image, a tuple, an arbitrarily nested dictionary... pretty much | ||||||||||||||||||||
|
@@ -160,15 +164,15 @@ | |||||||||||||||||||
# we're getting the same structure as output: | ||||||||||||||||||||
|
||||||||||||||||||||
target = { | ||||||||||||||||||||
"bboxes": bboxes, | ||||||||||||||||||||
"labels": torch.arange(bboxes.shape[0]), | ||||||||||||||||||||
"boxes": boxes, | ||||||||||||||||||||
"labels": torch.arange(boxes.shape[0]), | ||||||||||||||||||||
"this_is_ignored": ("arbitrary", {"structure": "!"}) | ||||||||||||||||||||
} | ||||||||||||||||||||
|
||||||||||||||||||||
# Re-using the transforms and definitions from above. | ||||||||||||||||||||
out_img, out_target = transforms(img, target) | ||||||||||||||||||||
|
||||||||||||||||||||
plot([(img, target["bboxes"]), (out_img, out_target["bboxes"])]) | ||||||||||||||||||||
plot([(img, target["boxes"]), (out_img, out_target["boxes"])]) | ||||||||||||||||||||
print(f"{out_target['this_is_ignored']}") | ||||||||||||||||||||
|
||||||||||||||||||||
# %% | ||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -1,146 +1,169 @@ | ||||||
""" | ||||||
Check warning on line 1 in gallery/v2_transforms/plot_transforms_v2_e2e.py GitHub Actions / bc
|
||||||
================================================== | ||||||
Transforms v2: End-to-end object detection example | ||||||
================================================== | ||||||
=============================================================== | ||||||
Transforms v2: End-to-end object detection/segmentation example | ||||||
=============================================================== | ||||||
|
||||||
.. note:: | ||||||
Try on `collab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_transforms_v2_e2e.ipynb>`_ | ||||||
or :ref:`go to the end <sphx_glr_download_auto_examples_v2_transforms_plot_transforms_v2_e2e.py>` to download the full example code. | ||||||
|
||||||
Object detection is not supported out of the box by ``torchvision.transforms`` v1, since it only supports images. | ||||||
``torchvision.transforms.v2`` enables jointly transforming images, videos, bounding boxes, and masks. This example | ||||||
showcases an end-to-end object detection training using the stable ``torchvision.datasets`` and ``torchvision.models`` | ||||||
as well as the new ``torchvision.transforms.v2`` v2 API. | ||||||
Object detection and segmentation tasks are natively supported: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using "segmentation" loosely here (i.e. both instance and semantic). It's made more accurate a bit later in the example, the main point being that it doesn't matter which one you're targetting, the transforms work the same. |
||||||
``torchvision.transforms.v2`` enables jointly transforming images, videos, | ||||||
bounding boxes, and masks. | ||||||
|
||||||
This example showcases an end-to-end instance segmentation training case using | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it's fine to keep as-is considering the sentence just below. Technically, it is an instance segmentation model. |
||||||
Torchvision utils from ``torchvision.datasets``, ``torchvision.models`` and | ||||||
``torchvision.transforms.v2``. Everything covered here can be applied similarly | ||||||
to instance detection or semantic segmentation tasks. | ||||||
""" | ||||||
|
||||||
# %% | ||||||
import pathlib | ||||||
|
||||||
import PIL.Image | ||||||
|
||||||
import torch | ||||||
import torch.utils.data | ||||||
|
||||||
from torchvision import models, datasets | ||||||
import torchvision.transforms.v2 as transforms | ||||||
|
||||||
|
||||||
def show(sample): | ||||||
import matplotlib.pyplot as plt | ||||||
|
||||||
from torchvision.transforms.v2 import functional as F | ||||||
from torchvision.utils import draw_bounding_boxes | ||||||
|
||||||
image, target = sample | ||||||
if isinstance(image, PIL.Image.Image): | ||||||
image = F.to_image(image) | ||||||
image = F.to_dtype(image, torch.uint8, scale=True) | ||||||
annotated_image = draw_bounding_boxes(image, target["boxes"], colors="yellow", width=3) | ||||||
from torchvision import models, datasets, datapoints | ||||||
from torchvision.transforms import v2 | ||||||
|
||||||
fig, ax = plt.subplots() | ||||||
ax.imshow(annotated_image.permute(1, 2, 0).numpy()) | ||||||
ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) | ||||||
fig.tight_layout() | ||||||
torch.manual_seed(0) | ||||||
|
||||||
fig.show() | ||||||
# This loads fake data for illustration purposes of this example. In practice, you'll have | ||||||
# to replace this with the proper data. | ||||||
# If you're trying to run that on collab, you can download the assets and the | ||||||
# helpers from https://github.com/pytorch/vision/tree/main/gallery/ | ||||||
ROOT = pathlib.Path("../assets") / "coco" | ||||||
IMAGES_PATH = str(ROOT / "images") | ||||||
ANNOTATIONS_PATH = str(ROOT / "instances.json") | ||||||
from helpers import plot | ||||||
|
||||||
|
||||||
# %% | ||||||
# Dataset preparation | ||||||
# ------------------- | ||||||
# | ||||||
# We start off by loading the :class:`~torchvision.datasets.CocoDetection` dataset to have a look at what it currently | ||||||
# returns, and we'll see how to convert it to a format that is compatible with our new transforms. | ||||||
|
||||||
def load_example_coco_detection_dataset(**kwargs): | ||||||
# This loads fake data for illustration purposes of this example. In practice, you'll have | ||||||
# to replace this with the proper data | ||||||
root = pathlib.Path("../assets") / "coco" | ||||||
return datasets.CocoDetection(str(root / "images"), str(root / "instances.json"), **kwargs) | ||||||
|
||||||
# returns. | ||||||
|
||||||
dataset = load_example_coco_detection_dataset() | ||||||
dataset = datasets.CocoDetection(IMAGES_PATH, ANNOTATIONS_PATH) | ||||||
|
||||||
sample = dataset[0] | ||||||
image, target = sample | ||||||
print(type(image)) | ||||||
print(type(target), type(target[0]), list(target[0].keys())) | ||||||
img, target = sample | ||||||
print(f"{type(img) = }\n{type(target) = }\n{type(target[0]) = }\n{target[0].keys() = }") | ||||||
|
||||||
|
||||||
# %% | ||||||
# The dataset returns a two-tuple with the first item being a :class:`PIL.Image.Image` and second one a list of | ||||||
# dictionaries, which each containing the annotations for a single object instance. As is, this format is not compatible | ||||||
# with the ``torchvision.transforms.v2``, nor with the models. To overcome that, we provide the | ||||||
# Torchvision datasets preserve the data structure and types as it was intended | ||||||
# by the datasets authors. So by default, the output structure may not always be | ||||||
# compatible with the models or the transforms. | ||||||
# | ||||||
# To overcome that, we can use the | ||||||
# :func:`~torchvision.datasets.wrap_dataset_for_transforms_v2` function. For | ||||||
# :class:`~torchvision.datasets.CocoDetection`, this changes the target structure to a single dictionary of lists. It | ||||||
# also adds the key-value-pairs ``"boxes"``, ``"masks"``, and ``"labels"`` wrapped in the corresponding | ||||||
# ``torchvision.datapoints``. By default, it only returns ``"boxes"`` and ``"labels"`` to avoid transforming unnecessary | ||||||
# items down the line, but you can pass the ``target_type`` parameter for fine-grained control. | ||||||
# :class:`~torchvision.datasets.CocoDetection`, this changes the target | ||||||
# structure to a single dictionary of lists: | ||||||
|
||||||
dataset = datasets.wrap_dataset_for_transforms_v2(dataset) | ||||||
dataset = datasets.wrap_dataset_for_transforms_v2(dataset, target_keys=("boxes", "labels", "masks")) | ||||||
|
||||||
sample = dataset[0] | ||||||
image, target = sample | ||||||
print(type(image)) | ||||||
print(type(target), list(target.keys())) | ||||||
print(type(target["boxes"]), type(target["labels"])) | ||||||
img, target = sample | ||||||
print(f"{type(img) = }\n{type(target) = }\n{target.keys() = }") | ||||||
print(f"{type(target['boxes']) = }\n{type(target['labels']) = }\n{type(target['masks']) = }") | ||||||
|
||||||
# We used the ``target_keys`` parameter to specify the kind of output we're | ||||||
# interested in. Our dataset now returns a target which is dict where the values | ||||||
# are :ref:`Datapoints <what_are_datapoints>` (all are :class:`torch.Tensor` | ||||||
# subclasses). We're dropped all unncessary keys from the previous output, but | ||||||
# if you need any of the original keys e.g. "image_id", you can still ask for | ||||||
# it. | ||||||
# | ||||||
# .. note:: | ||||||
# | ||||||
# If you just want to do detection, you don't need and shouldn't pass | ||||||
# "masks" in ``target_keys``: if masks are present in the sample, they will | ||||||
# be transformed, slowing down your transformations unnecessarily. | ||||||
|
||||||
|
||||||
# %% | ||||||
# As baseline, let's have a look at a sample without transformations: | ||||||
|
||||||
show(sample) | ||||||
plot([dataset[0], dataset[1]]) | ||||||
|
||||||
|
||||||
# %% | ||||||
# With the dataset properly set up, we can now define the augmentation pipeline. This is done the same way it is done in | ||||||
# ``torchvision.transforms`` v1, but now handles bounding boxes and masks without any extra configuration. | ||||||
# Transforms | ||||||
# ---------- | ||||||
# | ||||||
# Let's now define our pre-processing transforms. All the transforms know how | ||||||
# to handle images, bouding boxes and masks when relevant. | ||||||
# | ||||||
# Transforms are typically passed as the ``transforms`` parameter of the | ||||||
# dataset so that they can leverage multi-processing from the | ||||||
# :class:`torch.utils.data.DataLoader`. | ||||||
|
||||||
transform = transforms.Compose( | ||||||
transforms = v2.Compose( | ||||||
[ | ||||||
transforms.RandomPhotometricDistort(), | ||||||
transforms.RandomZoomOut(fill={PIL.Image.Image: (123, 117, 104), "others": 0}), | ||||||
transforms.RandomIoUCrop(), | ||||||
transforms.RandomHorizontalFlip(), | ||||||
transforms.ToImage(), | ||||||
transforms.ConvertImageDtype(torch.float32), | ||||||
transforms.SanitizeBoundingBoxes(), | ||||||
v2.ToImage(), | ||||||
v2.RandomPhotometricDistort(p=1), | ||||||
v2.RandomZoomOut(fill={datapoints.Image: (123, 117, 104), "others": 0}), | ||||||
v2.RandomIoUCrop(), | ||||||
v2.RandomHorizontalFlip(p=1), | ||||||
v2.SanitizeBoundingBoxes(), | ||||||
v2.ToDtype(torch.float32, scale=True), | ||||||
] | ||||||
) | ||||||
|
||||||
dataset = datasets.CocoDetection(IMAGES_PATH, ANNOTATIONS_PATH, transforms=transforms) | ||||||
dataset = datasets.wrap_dataset_for_transforms_v2(dataset, target_keys=["boxes", "labels", "masks"]) | ||||||
|
||||||
# %% | ||||||
# .. note:: | ||||||
# Although the :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes` transform is a no-op in this example, but it | ||||||
# should be placed at least once at the end of a detection pipeline to remove degenerate bounding boxes as well as | ||||||
# the corresponding labels and optionally masks. It is particularly critical to add it if | ||||||
# :class:`~torchvision.transforms.v2.RandomIoUCrop` was used. | ||||||
# A few things are worth noting here: | ||||||
# | ||||||
# - We're converting the PIL image into a | ||||||
# :class:`~torchvision.transforms.v2.Image` object. This isn't strictly | ||||||
# necessary, but relying on Tensors (here: a Tensor subclass) will | ||||||
# :ref:`generally be faster <transforms_perf>`. | ||||||
# - We are calling :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes` to | ||||||
# make sure we remove degenerate bounding boxes, as well as their | ||||||
# corresponding labels and masks. | ||||||
# :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes` should be placed | ||||||
# at least once at the end of a detection pipeline; it is particularly | ||||||
# critical if :class:`~torchvision.transforms.v2.RandomIoUCrop` was used. | ||||||
# | ||||||
# Let's look how the sample looks like with our augmentation pipeline in place: | ||||||
|
||||||
dataset = load_example_coco_detection_dataset(transforms=transform) | ||||||
dataset = datasets.wrap_dataset_for_transforms_v2(dataset) | ||||||
|
||||||
torch.manual_seed(3141) | ||||||
sample = dataset[0] | ||||||
|
||||||
# sphinx_gallery_thumbnail_number = 2 | ||||||
show(sample) | ||||||
plot([dataset[0], dataset[1]]) | ||||||
|
||||||
|
||||||
# %% | ||||||
# We can see that the color of the image was distorted, we zoomed out on it (off center) and flipped it horizontally. | ||||||
# In all of this, the bounding box was transformed accordingly. And without any further ado, we can start training. | ||||||
# We can see that the color of the images were distorted, zoomed in or out, and flipped. | ||||||
# The bounding boxes and the masks were transformed accordingly. And without any further ado, we can start training. | ||||||
# | ||||||
# Data loading and training loop | ||||||
# ------------------------------ | ||||||
# | ||||||
# Below we're using Mask-RCNN which is an instance segmentation model, but | ||||||
# everything we've covered in this tutorial also applies to object detection and | ||||||
# semantic segmentation tasks. | ||||||
|
||||||
data_loader = torch.utils.data.DataLoader( | ||||||
dataset, | ||||||
batch_size=2, | ||||||
# We need a custom collation function here, since the object detection models expect a | ||||||
# sequence of images and target dictionaries. The default collation function tries to | ||||||
# `torch.stack` the individual elements, which fails in general for object detection, | ||||||
# because the number of object instances varies between the samples. This is the same for | ||||||
# `torchvision.transforms` v1 | ||||||
# We need a custom collation function here, since the object detection | ||||||
# models expect a sequence of images and target dictionaries. The default | ||||||
# collation function tries to :func:`~torch.stack` the individual elements, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. rst formatting is useless here as this is a code comment (will self-address later)
Suggested change
|
||||||
# which fails in general for object detection, because the number of bouding | ||||||
# boxes varies between the images of a same batch. | ||||||
collate_fn=lambda batch: tuple(zip(*batch)), | ||||||
) | ||||||
|
||||||
model = models.get_model("ssd300_vgg16", weights=None, weights_backbone=None).train() | ||||||
model = models.get_model("maskrcnn_resnet50_fpn_v2", weights=None, weights_backbone=None).train() | ||||||
|
||||||
for images, targets in data_loader: | ||||||
loss_dict = model(images, targets) | ||||||
print(loss_dict) | ||||||
for imgs, targets in data_loader: | ||||||
loss_dict = model(imgs, targets) | ||||||
# Put your training logic here | ||||||
break | ||||||
|
||||||
print(f"{[img.shape for img in imgs] = }") | ||||||
print(f"{[type(target) for target in targets] = }") | ||||||
for name, loss_val in loss_dict.items(): | ||||||
print(f"{name:<20}{loss_val:.3f}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changes in this file are only needed to support masks