Skip to content

Commit

Permalink
Add gallery example for MixUp and CutMix (#7772)
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug authored Jul 31, 2023
1 parent 8d4e879 commit 9b4ec8d
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 13 deletions.
8 changes: 4 additions & 4 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -261,13 +261,13 @@ The new transform can be used standalone or mixed-and-matched with existing tran
AugMix
v2.AugMix

Cutmix - Mixup
CutMix - MixUp
--------------

Cutmix and Mixup are special transforms that
CutMix and MixUp are special transforms that
are meant to be used on batches rather than on individual images, because they
are combining pairs of images together. These can be used after the dataloader,
or part of a collation function. See
are combining pairs of images together. These can be used after the dataloader
(once the samples are batched), or part of a collation function. See
:ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage examples.

.. autosummary::
Expand Down
148 changes: 146 additions & 2 deletions gallery/plot_cutmix_mixup.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,152 @@

"""
===========================
How to use Cutmix and Mixup
How to use CutMix and MixUp
===========================
TODO
:class:`~torchvision.transforms.v2.Cutmix` and
:class:`~torchvision.transforms.v2.Mixup` are popular augmentation strategies
that can improve classification accuracy.
These transforms are slightly different from the rest of the Torchvision
transforms, because they expect
**batches** of samples as input, not individual images. In this example we'll
explain how to use them: after the ``DataLoader``, or as part of a collation
function.
"""

# %%
import torch
import torchvision
from torchvision.datasets import FakeData

# We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that
# some APIs may slightly change in the future
torchvision.disable_beta_transforms_warning()

from torchvision.transforms import v2


NUM_CLASSES = 100

# %%
# Pre-processing pipeline
# -----------------------
#
# We'll use a simple but typical image classification pipeline:

preproc = v2.Compose([
v2.PILToTensor(),
v2.RandomResizedCrop(size=(224, 224), antialias=True),
v2.RandomHorizontalFlip(p=0.5),
v2.ToDtype(torch.float32, scale=True), # to float32 in [0, 1]
v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), # typically from ImageNet
])

dataset = FakeData(size=1000, num_classes=NUM_CLASSES, transform=preproc)

img, label = dataset[0]
print(f"{type(img) = }, {img.dtype = }, {img.shape = }, {label = }")

# %%
#
# One important thing to note is that neither CutMix nor MixUp are part of this
# pre-processing pipeline. We'll add them a bit later once we define the
# DataLoader. Just as a refresher, this is what the DataLoader and training loop
# would look like if we weren't using CutMix or MixUp:

from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

for images, labels in dataloader:
print(f"{images.shape = }, {labels.shape = }")
print(labels.dtype)
# <rest of the training loop here>
break
# %%

# %%
# Where to use MixUp and CutMix
# -----------------------------
#
# After the DataLoader
# ^^^^^^^^^^^^^^^^^^^^
#
# Now let's add CutMix and MixUp. The simplest way to do this right after the
# DataLoader: the Dataloader has already batched the images and labels for us,
# and this is exactly what these transforms expect as input:

dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

cutmix = v2.Cutmix(num_classes=NUM_CLASSES)
mixup = v2.Mixup(num_classes=NUM_CLASSES)
cutmix_or_mixup = v2.RandomChoice([cutmix, mixup])

for images, labels in dataloader:
print(f"Before CutMix/MixUp: {images.shape = }, {labels.shape = }")
images, labels = cutmix_or_mixup(images, labels)
print(f"After CutMix/MixUp: {images.shape = }, {labels.shape = }")

# <rest of the training loop here>
break
# %%
#
# Note how the labels were also transformed: we went from a batched label of
# shape (batch_size,) to a tensor of shape (batch_size, num_classes). The
# transformed labels can still be passed as-is to a loss function like
# :func:`torch.nn.functional.cross_entropy`.
#
# As part of the collation function
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Passing the transforms after the DataLoader is the simplest way to use CutMix
# and MixUp, but one disadvantage is that it does not take advantage of the
# DataLoader multi-processing. For that, we can pass those transforms as part of
# the collation function (refer to the `PyTorch docs
# <https://pytorch.org/docs/stable/data.html#dataloader-collate-fn>`_ to learn
# more about collation).

from torch.utils.data import default_collate


def collate_fn(batch):
return cutmix_or_mixup(*default_collate(batch))


dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2, collate_fn=collate_fn)

for images, labels in dataloader:
print(f"{images.shape = }, {labels.shape = }")
# No need to call cutmix_or_mixup, it's already been called as part of the DataLoader!
# <rest of the training loop here>
break

# %%
# Non-standard input format
# -------------------------
#
# So far we've used a typical sample structure where we pass ``(images,
# labels)`` as inputs. MixUp and CutMix will magically work by default with most
# common sample structures: tuples where the second parameter is a tensor label,
# or dict with a "label[s]" key. Look at the documentation of the
# ``labels_getter`` parameter for more details.
#
# If your samples have a different structure, you can still use CutMix and MixUp
# by passing a callable to the ``labels_getter`` parameter. For example:

batch = {
"imgs": torch.rand(4, 3, 224, 224),
"target": {
"classes": torch.randint(0, NUM_CLASSES, size=(4,)),
"some_other_key": "this is going to be passed-through"
}
}


def labels_getter(batch):
return batch["target"]["classes"]


out = v2.Cutmix(num_classes=NUM_CLASSES, labels_getter=labels_getter)(batch)
print(f"{out['imgs'].shape = }, {out['target']['classes'].shape = }")
2 changes: 1 addition & 1 deletion test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -1922,7 +1922,7 @@ def test_supported_input_structure(self, T):

dataset = self.DummyDataset(size=batch_size, num_classes=num_classes)

cutmix_mixup = T(alpha=0.5, num_classes=num_classes)
cutmix_mixup = T(num_classes=num_classes)

dl = DataLoader(dataset, batch_size=batch_size)

Expand Down
26 changes: 20 additions & 6 deletions torchvision/transforms/v2/_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,9 @@ def _transform(


class _BaseMixupCutmix(Transform):
def __init__(self, *, alpha: float = 1, num_classes: int, labels_getter="default") -> None:
def __init__(self, *, alpha: float = 1.0, num_classes: int, labels_getter="default") -> None:
super().__init__()
self.alpha = alpha
self.alpha = float(alpha)
self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))

self.num_classes = num_classes
Expand Down Expand Up @@ -204,13 +204,20 @@ def _mixup_label(self, label: torch.Tensor, *, lam: float) -> torch.Tensor:


class Mixup(_BaseMixupCutmix):
"""[BETA] Apply Mixup to the provided batch of images and labels.
"""[BETA] Apply MixUp to the provided batch of images and labels.
.. v2betastatus:: Mixup transform
Paper: `mixup: Beyond Empirical Risk Minimization <https://arxiv.org/abs/1710.09412>`_.
See :ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage examples.
.. note::
This transform is meant to be used on **batches** of samples, not
individual images. See
:ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage
examples.
The sample pairing is deterministic and done by matching consecutive
samples in the batch, so the batch needs to be shuffled (this is an
implementation detail, not a guaranteed convention.)
In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed
into a tensor of shape ``(batch_size, num_classes)``.
Expand Down Expand Up @@ -246,14 +253,21 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:


class Cutmix(_BaseMixupCutmix):
"""[BETA] Apply Cutmix to the provided batch of images and labels.
"""[BETA] Apply CutMix to the provided batch of images and labels.
.. v2betastatus:: Cutmix transform
Paper: `CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features
<https://arxiv.org/abs/1905.04899>`_.
See :ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage examples.
.. note::
This transform is meant to be used on **batches** of samples, not
individual images. See
:ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage
examples.
The sample pairing is deterministic and done by matching consecutive
samples in the batch, so the batch needs to be shuffled (this is an
implementation detail, not a guaranteed convention.)
In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed
into a tensor of shape ``(batch_size, num_classes)``.
Expand Down

0 comments on commit 9b4ec8d

Please sign in to comment.