Commit: Merge branch 'main' into add_mps_kernels
NicolasHug authored Aug 1, 2023
2 parents b1cf619 + 3e4e353 commit 3f82ee4
Showing 70 changed files with 3,069 additions and 2,248 deletions.
.git-blame-ignore-revs (2 additions, 0 deletions)
@@ -9,3 +9,5 @@ d367a01a18a3ae6bee13d8be3b63fd6a581ea46f
6ca9c76adb6daf2695d603ad623a9cf1c4f4806f
# Fix unnecessary exploded black formatting (#7709)
a335d916db0694770e8152f41e19195de3134523
+# Renaming: `BoundingBox` -> `BoundingBoxes` (#7778)
+332bff937c6711666191880fab57fa2f23ae772e
docs/source/datapoints.rst (1 addition, 1 deletion)
@@ -15,5 +15,5 @@ see e.g. :ref:`sphx_glr_auto_examples_plot_transforms_v2_e2e.py`.
Image
Video
BoundingBoxFormat
-BoundingBox
+BoundingBoxes
Mask
docs/source/transforms.rst (18 additions, 3 deletions)
@@ -206,8 +206,8 @@ Miscellaneous
v2.RandomErasing
Lambda
v2.Lambda
-v2.SanitizeBoundingBox
-v2.ClampBoundingBox
+v2.SanitizeBoundingBoxes
+v2.ClampBoundingBoxes
v2.UniformTemporalSubsample

.. _conversion_transforms:
@@ -234,7 +234,6 @@ Conversion
v2.PILToTensor
v2.ToImageTensor
ConvertImageDtype
-v2.ConvertDtype
v2.ConvertImageDtype
v2.ToDtype
v2.ConvertBoundingBoxFormat
@@ -262,6 +261,22 @@ The new transform can be used standalone or mixed-and-matched with existing tran
AugMix
v2.AugMix

+CutMix - MixUp
+--------------
+
+CutMix and MixUp are special transforms that are meant to be used on batches
+rather than on individual images, because they combine pairs of images
+together. They can be used after the dataloader (once the samples are
+batched), or as part of a collation function. See
+:ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage examples.
+
+.. autosummary::
+    :toctree: generated/
+    :template: class.rst
+
+    v2.CutMix
+    v2.MixUp
+
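A minimal, hedged sketch of the pattern this new section describes (the full
tutorial is the gallery file added below; the tensor shapes and NUM_CLASSES
here are illustrative assumptions):

import torch
import torchvision
torchvision.disable_beta_transforms_warning()  # v2 was still in beta at this commit
from torchvision.transforms import v2

NUM_CLASSES = 100
cutmix_or_mixup = v2.RandomChoice([
    v2.CutMix(num_classes=NUM_CLASSES),
    v2.MixUp(num_classes=NUM_CLASSES),
])

# A fake batch, as a DataLoader would produce it.
images = torch.rand(4, 3, 224, 224)
labels = torch.randint(0, NUM_CLASSES, size=(4,))  # shape (batch_size,)
images, labels = cutmix_or_mixup(images, labels)   # labels become (batch_size, num_classes)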
.. _functional_transforms:

Functional Transforms
gallery/plot_cutmix_mixup.py (new file: 152 additions)
@@ -0,0 +1,152 @@

"""
===========================
How to use CutMix and MixUp
===========================
:class:`~torchvision.transforms.v2.CutMix` and
:class:`~torchvision.transforms.v2.MixUp` are popular augmentation strategies
that can improve classification accuracy.
These transforms are slightly different from the rest of the Torchvision
transforms, because they expect
**batches** of samples as input, not individual images. In this example we'll
explain how to use them: after the ``DataLoader``, or as part of a collation
function.
"""

# %%
import torch
import torchvision
from torchvision.datasets import FakeData

# We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that
# some APIs may slightly change in the future
torchvision.disable_beta_transforms_warning()

from torchvision.transforms import v2


NUM_CLASSES = 100

# %%
# Pre-processing pipeline
# -----------------------
#
# We'll use a simple but typical image classification pipeline:

preproc = v2.Compose([
    v2.PILToTensor(),
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True),  # to float32 in [0, 1]
    v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),  # typically from ImageNet
])

dataset = FakeData(size=1000, num_classes=NUM_CLASSES, transform=preproc)

img, label = dataset[0]
print(f"{type(img) = }, {img.dtype = }, {img.shape = }, {label = }")

# %%
#
# One important thing to note is that neither CutMix nor MixUp is part of this
# pre-processing pipeline. We'll add them a bit later once we define the
# DataLoader. Just as a refresher, this is what the DataLoader and training loop
# would look like if we weren't using CutMix or MixUp:

from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

for images, labels in dataloader:
    print(f"{images.shape = }, {labels.shape = }")
    print(labels.dtype)
    # <rest of the training loop here>
    break

# %%
# Where to use MixUp and CutMix
# -----------------------------
#
# After the DataLoader
# ^^^^^^^^^^^^^^^^^^^^
#
# Now let's add CutMix and MixUp. The simplest way to do this is right after the
# DataLoader: the DataLoader has already batched the images and labels for us,
# and this is exactly what these transforms expect as input:

dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

cutmix = v2.CutMix(num_classes=NUM_CLASSES)
mixup = v2.MixUp(num_classes=NUM_CLASSES)
cutmix_or_mixup = v2.RandomChoice([cutmix, mixup])

for images, labels in dataloader:
    print(f"Before CutMix/MixUp: {images.shape = }, {labels.shape = }")
    images, labels = cutmix_or_mixup(images, labels)
    print(f"After CutMix/MixUp: {images.shape = }, {labels.shape = }")

    # <rest of the training loop here>
    break
# %%
#
# Note how the labels were also transformed: we went from a batched label of
# shape (batch_size,) to a tensor of shape (batch_size, num_classes). The
# transformed labels can still be passed as-is to a loss function like
# :func:`torch.nn.functional.cross_entropy`.
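#
# As a minimal sketch (hedged: the random logits below stand in for the output
# of a real model, which this example doesn't define):

logits = torch.randn(labels.shape)  # stand-in for model(images)
loss = torch.nn.functional.cross_entropy(logits, labels)
print(f"{loss = }")

# %%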
#
# As part of the collation function
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Passing the transforms after the DataLoader is the simplest way to use CutMix
# and MixUp, but one disadvantage is that it does not take advantage of the
# DataLoader's multi-processing. For that, we can pass those transforms as part of
# the collation function (refer to the `PyTorch docs
# <https://pytorch.org/docs/stable/data.html#dataloader-collate-fn>`_ to learn
# more about collation).

from torch.utils.data import default_collate


def collate_fn(batch):
    return cutmix_or_mixup(*default_collate(batch))


dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2, collate_fn=collate_fn)

for images, labels in dataloader:
    print(f"{images.shape = }, {labels.shape = }")
    # No need to call cutmix_or_mixup, it's already been called as part of the DataLoader!
    # <rest of the training loop here>
    break

# %%
# Non-standard input format
# -------------------------
#
# So far we've used a typical sample structure where we pass ``(images,
# labels)`` as inputs. MixUp and CutMix will magically work by default with most
# common sample structures: tuples where the second entry is a tensor label,
# or dicts with a "label[s]" key. Look at the documentation of the
# ``labels_getter`` parameter for more details.
#
# If your samples have a different structure, you can still use CutMix and MixUp
# by passing a callable to the ``labels_getter`` parameter. For example:

batch = {
    "imgs": torch.rand(4, 3, 224, 224),
    "target": {
        "classes": torch.randint(0, NUM_CLASSES, size=(4,)),
        "some_other_key": "this is going to be passed-through",
    }
}


def labels_getter(batch):
    return batch["target"]["classes"]


out = v2.CutMix(num_classes=NUM_CLASSES, labels_getter=labels_getter)(batch)
print(f"{out['imgs'].shape = }, {out['target']['classes'].shape = }")
gallery/plot_datapoints.py (9 additions, 9 deletions)
@@ -47,7 +47,7 @@
#
# * :class:`~torchvision.datapoints.Image`
# * :class:`~torchvision.datapoints.Video`
-# * :class:`~torchvision.datapoints.BoundingBox`
+# * :class:`~torchvision.datapoints.BoundingBoxes`
# * :class:`~torchvision.datapoints.Mask`
#
# How do I construct a datapoint?
@@ -76,11 +76,11 @@

########################################################################################################################
# In general, the datapoints can also store additional metadata that complements the underlying tensor. For example,
-# :class:`~torchvision.datapoints.BoundingBox` stores the coordinate format as well as the spatial size of the
+# :class:`~torchvision.datapoints.BoundingBoxes` stores the coordinate format as well as the spatial size of the
# corresponding image alongside the actual values:

-bounding_box = datapoints.BoundingBox(
-    [17, 16, 344, 495], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=image.shape[-2:]
+bounding_box = datapoints.BoundingBoxes(
+    [17, 16, 344, 495], format=datapoints.BoundingBoxFormat.XYXY, canvas_size=image.shape[-2:]
)
print(bounding_box)

@@ -105,10 +105,10 @@ class PennFudanDataset(torch.utils.data.Dataset):
    def __getitem__(self, item):
        ...

-        target["boxes"] = datapoints.BoundingBox(
+        target["boxes"] = datapoints.BoundingBoxes(
            boxes,
            format=datapoints.BoundingBoxFormat.XYXY,
-            spatial_size=F.get_spatial_size(img),
+            canvas_size=F.get_size(img),
        )
        target["labels"] = labels
        target["masks"] = datapoints.Mask(masks)
@@ -126,10 +126,10 @@ def __getitem__(self, item):

class WrapPennFudanDataset:
    def __call__(self, img, target):
-        target["boxes"] = datapoints.BoundingBox(
+        target["boxes"] = datapoints.BoundingBoxes(
            target["boxes"],
            format=datapoints.BoundingBoxFormat.XYXY,
-            spatial_size=F.get_spatial_size(img),
+            canvas_size=F.get_size(img),
        )
        target["masks"] = datapoints.Mask(target["masks"])
        return img, target
@@ -147,7 +147,7 @@ def get_transform(train):
########################################################################################################################
# .. note::
#
-# If both :class:`~torchvision.datapoints.BoundingBox`'es and :class:`~torchvision.datapoints.Mask`'s are included in
+# If both :class:`~torchvision.datapoints.BoundingBoxes` and :class:`~torchvision.datapoints.Mask` instances are included in
# the sample, ``torchvision.transforms.v2`` will transform them both. This means that if you don't need both, dropping
# or at least not wrapping the obsolete parts can lead to a significant performance boost.
#
gallery/plot_transforms_v2.py (2 additions, 2 deletions)
@@ -29,8 +29,8 @@ def load_data():

    masks = datapoints.Mask(merged_masks == labels.view(-1, 1, 1))

-    bounding_boxes = datapoints.BoundingBox(
-        masks_to_boxes(masks), format=datapoints.BoundingBoxFormat.XYXY, spatial_size=image.shape[-2:]
+    bounding_boxes = datapoints.BoundingBoxes(
+        masks_to_boxes(masks), format=datapoints.BoundingBoxFormat.XYXY, canvas_size=image.shape[-2:]
    )

    return path, image, bounding_boxes, masks, labels
gallery/plot_transforms_v2_e2e.py (4 additions, 7 deletions)
@@ -10,7 +10,6 @@
"""

import pathlib
-from collections import defaultdict

import PIL.Image

@@ -29,7 +28,7 @@ def show(sample):
    image, target = sample
    if isinstance(image, PIL.Image.Image):
        image = F.to_image_tensor(image)
-    image = F.convert_dtype(image, torch.uint8)
+    image = F.to_dtype(image, torch.uint8, scale=True)
    annotated_image = draw_bounding_boxes(image, target["boxes"], colors="yellow", width=3)

    fig, ax = plt.subplots()
@@ -99,20 +98,18 @@ def load_example_coco_detection_dataset(**kwargs):
transform = transforms.Compose(
    [
        transforms.RandomPhotometricDistort(),
-        transforms.RandomZoomOut(
-            fill=defaultdict(lambda: 0, {PIL.Image.Image: (123, 117, 104)})
-        ),
+        transforms.RandomZoomOut(fill={PIL.Image.Image: (123, 117, 104), "others": 0}),
        transforms.RandomIoUCrop(),
        transforms.RandomHorizontalFlip(),
        transforms.ToImageTensor(),
        transforms.ConvertImageDtype(torch.float32),
-        transforms.SanitizeBoundingBox(),
+        transforms.SanitizeBoundingBoxes(),
    ]
)

########################################################################################################################
# .. note::
-# Although the :class:`~torchvision.transforms.v2.SanitizeBoundingBox` transform is a no-op in this example, but it
+# Although the :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes` transform is a no-op in this example, it
# should be placed at least once at the end of a detection pipeline to remove degenerate bounding boxes as well as
# the corresponding labels and optionally masks. It is particularly critical to add it if
# :class:`~torchvision.transforms.v2.RandomIoUCrop` was used.
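A hedged sketch of what that sanitization does, assuming the beta v2 API at
this commit (the degenerate box and its label below are made up for
illustration):

import torch
import torchvision
torchvision.disable_beta_transforms_warning()
from torchvision import datapoints
from torchvision.transforms import v2

boxes = datapoints.BoundingBoxes(
    [[10, 10, 100, 100], [50, 50, 50, 50]],  # the second box has zero width/height
    format=datapoints.BoundingBoxFormat.XYXY,
    canvas_size=(224, 224),
)
target = {"boxes": boxes, "labels": torch.tensor([1, 2])}
img = torch.rand(3, 224, 224)

img, target = v2.SanitizeBoundingBoxes()(img, target)
print(target["boxes"].shape, target["labels"])  # the degenerate box and its label are removed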
(Diff for the remaining changed files not shown.)
