
Adding Mixup and Cutmix #4379

Merged · 19 commits · Sep 15, 2021
Changes from 10 commits
18 changes: 16 additions & 2 deletions references/classification/train.py
@@ -4,6 +4,7 @@

import torch
import torch.utils.data
from torch.utils.data.dataloader import default_collate
from torch import nn
import torchvision
from torchvision.transforms.functional import InterpolationMode
@@ -165,10 +166,21 @@ def main(args):
    train_dir = os.path.join(args.data_path, 'train')
    val_dir = os.path.join(args.data_path, 'val')
    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)

    collate_fn = None
    num_classes = len(dataset.classes)
Comment (Member): Not for now, but this exposes a limitation of our current datasets: we don't consistently enforce a way of querying the number of classes in a dataset. The dataset refactoring work from @pmeier will address this.

Comment (Contributor): With #4432, you will be able to do

info = torchvision.datasets.info(name)
info.categories

where categories is a list of strings in which the index corresponds to the label.

    mixup_transforms = []
    if args.mixup_alpha > 0.0:
        mixup_transforms.append(torchvision.transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
    if args.cutmix_alpha > 0.0:
        mixup_transforms.append(torchvision.transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
    if mixup_transforms:
        # RandomChoice defaults to uniform weights, which also works when only one
        # of mixup/cutmix is enabled (an explicit p=[0.5, 0.5] would fail in that case).
        mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
        collate_fn = lambda batch: mixupcutmix(*default_collate(batch))  # noqa: E731
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size,
        sampler=train_sampler, num_workers=args.workers, pin_memory=True,
        collate_fn=collate_fn)
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.batch_size,
        sampler=test_sampler, num_workers=args.workers, pin_memory=True)
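
To illustrate the collate hook above, here is a minimal self-contained sketch (the num_classes=10, the alpha values, and the dummy data are assumptions for illustration, not the reference defaults):

import torch
from torch.utils.data.dataloader import default_collate
import torchvision

# default_collate stacks a list of (image, label) samples into batched tensors;
# the mixup/cutmix transform then mixes the batch and returns soft (N, K) targets.
mixupcutmix = torchvision.transforms.RandomChoice([
    torchvision.transforms.RandomMixup(10, p=1.0, alpha=0.2),
    torchvision.transforms.RandomCutmix(10, p=1.0, alpha=1.0),
])
samples = [(torch.rand(3, 32, 32), i % 10) for i in range(4)]  # dummy (img, label) pairs
batch, target = mixupcutmix(*default_collate(samples))
print(batch.shape, target.shape)  # torch.Size([4, 3, 32, 32]) torch.Size([4, 10])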
@@ -273,6 +285,8 @@ def get_args_parser(add_help=True):
    parser.add_argument('--label-smoothing', default=0.0, type=float,
                        help='label smoothing (default: 0.0)',
                        dest='label_smoothing')
    parser.add_argument('--mixup-alpha', default=0.0, type=float, help='mixup alpha (default: 0.0)')
    parser.add_argument('--cutmix-alpha', default=0.0, type=float, help='cutmix alpha (default: 0.0)')
Comment (Contributor Author): I'm exposing very few options here (the p params are hardcoded). Keeping it simple until we use it in models and see what we want to support.

Comment (Contributor): @datumbox nn.CrossEntropyLoss does not work when you use mixup or cutmix, because the target shape is (N, K) rather than (N,).

Comment (Contributor Author): I think this was added at pytorch/pytorch#63122.

This should be available on the latest stable version of PyTorch; see the docs.

Comment (Contributor): @datumbox OK, thanks.
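
For reference, a minimal sketch of the behaviour discussed above: since pytorch/pytorch#63122, nn.CrossEntropyLoss accepts class-probability targets of shape (N, K) in addition to class-index targets of shape (N,):

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
logits = torch.randn(8, 10)  # (N, K) model outputs

# Classic usage: hard class indices of shape (N,).
hard_targets = torch.randint(10, (8,))
loss_hard = criterion(logits, hard_targets)

# Soft targets of shape (N, K), e.g. as produced by mixup/cutmix.
soft_targets = torch.softmax(torch.randn(8, 10), dim=1)
loss_soft = criterion(logits, soft_targets)
print(loss_hard.item(), loss_soft.item())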

    parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
    parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
    parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
2 changes: 2 additions & 0 deletions references/classification/utils.py
@@ -178,6 +178,8 @@ def accuracy(output, target, topk=(1,)):
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        if target.ndim == 2:
            target = target.max(dim=1)[1]

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
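
A standalone sketch of the target-shape handling added above: for a one-hot or mixed (N, K) target, the argmax along dim=1 recovers the dominant class index, which is what the top-k comparison in accuracy expects.

import torch

soft = torch.tensor([[0.7, 0.3, 0.0],
                     [0.1, 0.2, 0.7]])  # (N, K) soft targets
hard = soft.max(dim=1)[1]               # indices of the max -> class labels
print(hard)                             # tensor([0, 2])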
3 changes: 2 additions & 1 deletion test/test_transforms.py
@@ -1311,7 +1311,8 @@ def test_random_choice():
            transforms.Resize(15),
            transforms.Resize(20),
            transforms.CenterCrop(10)
        ],
        [1 / 3, 1 / 3, 1 / 3]
    )
    img = transforms.ToPILImage()(torch.rand(3, 25, 25))
    num_samples = 250
77 changes: 77 additions & 0 deletions test/test_transforms_tensor.py
@@ -1,6 +1,10 @@
import os
import torch
from torch._utils_internal import get_file_path_2
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.dataloader import default_collate
from torchvision import transforms as T
from torchvision.io import read_image
from torchvision.transforms import functional as F
from torchvision.transforms import InterpolationMode

@@ -715,3 +719,76 @@ def test_gaussian_blur(device, meth_kwargs):
        T.GaussianBlur, meth_kwargs=meth_kwargs,
        test_exact_match=False, device=device, agg_method="max", tol=tol
    )


@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('transform', [T.RandomMixup, T.RandomCutmix])
@pytest.mark.parametrize('p', [0.0, 1.0])
@pytest.mark.parametrize('inplace', [True, False])
def test_random_mixupcutmix(device, transform, p, inplace):
    batch_size = 32
    num_classes = 10
    batch = torch.rand(batch_size, 3, 44, 56, device=device)
    targets = torch.randint(num_classes, (batch_size, ), device=device, dtype=torch.int64)

    fn = transform(num_classes, p=p, inplace=inplace)
    scripted_fn = torch.jit.script(fn)

    seed = torch.seed()
    output = fn(batch.clone(), targets.clone())

    torch.manual_seed(seed)
    output_scripted = scripted_fn(batch.clone(), targets.clone())
    assert_equal(output[0], output_scripted[0])
    assert_equal(output[1], output_scripted[1])

    fn.__repr__()


@pytest.mark.parametrize('transform', [T.RandomMixup, T.RandomCutmix])
def test_random_mixupcutmix_with_invalid_data(transform):
    with pytest.raises(AssertionError, match="Please provide a valid positive value for the num_classes."):
        transform(0)
    with pytest.raises(AssertionError, match="Alpha param can't be zero."):
        transform(10, alpha=0.0)

    t = transform(10)
    with pytest.raises(ValueError, match="Batch ndim should be 4."):
        t(torch.rand(3, 60, 60), torch.randint(10, (1, )))
    with pytest.raises(ValueError, match="Target ndim should be 1."):
        t(torch.rand(32, 3, 60, 60), torch.randint(10, (32, 1)))
    with pytest.raises(ValueError, match="Target dtype should be torch.int64."):
        t(torch.rand(32, 3, 60, 60), torch.randint(10, (32, ), dtype=torch.int32))


@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('transform, expected', [
    (T.RandomMixup, [60.77401351928711, 0.5151033997535706]),
    (T.RandomCutmix, [70.13909912109375, 0.525851309299469])
])
def test_random_mixupcutmix_with_real_data(device, transform, expected):
    torch.manual_seed(12)

    # Build dummy dataset
    images = []
    for test_file in [("encode_jpeg", "grace_hopper_517x606.jpg"), ("fakedata", "logos", "rgb_pytorch.png")]:
        fullpath = (os.path.dirname(os.path.abspath(__file__)), 'assets') + test_file
        img = read_image(get_file_path_2(*fullpath))
        images.append(F.resize(img, [224, 224]))
    dataset = TensorDataset(torch.stack(images).to(device=device, dtype=torch.float32),
                            torch.tensor([0, 1], device=device))

    # Use mixup in the collate
    trans = transform(2)
    dataloader = DataLoader(dataset, batch_size=2, collate_fn=lambda batch: trans(*default_collate(batch)))

    # Test against known statistics about the produced images
    stats = []
    for _ in range(25):
        for b, t in dataloader:
            stats.append(torch.stack([b.std(), t.std()]))

    torch.testing.assert_close(
        torch.stack(stats).mean(dim=0),
        torch.tensor(expected)
    )
182 changes: 178 additions & 4 deletions torchvision/transforms/transforms.py
@@ -22,7 +22,8 @@
"RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
"LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
"RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize",
"RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize"]
"RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize", 'RandomMixup',
"RandomCutmix"]


class Compose:
@@ -515,9 +516,20 @@ def __call__(self, img):
class RandomChoice(RandomTransforms):
    """Apply single transformation randomly picked from a list. This transform does not support torchscript.
    """
    def __init__(self, transforms, p=None):
        super().__init__(transforms)
        if p is not None and not isinstance(p, Sequence):
            raise TypeError("Argument p should be a sequence")
        self.p = p

    def __call__(self, *args):
        # Uniform when p is None; otherwise p gives the relative weights of the transforms.
        t = random.choices(self.transforms, weights=self.p)[0]
        return t(*args)

    def __repr__(self):
        format_string = super().__repr__()
        format_string += '(p={0})'.format(self.p)
        return format_string
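
A quick usage sketch of the weighted RandomChoice above (the transforms and weights are arbitrary examples):

import torch
from torchvision import transforms

# Pick Resize 80% of the time and CenterCrop 20% of the time.
t = transforms.RandomChoice(
    [transforms.Resize(15), transforms.CenterCrop(10)],
    p=[0.8, 0.2])
img = transforms.ToPILImage()(torch.rand(3, 25, 25))
out = t(img)  # either resized to 15 on the short side or center-cropped to 10x10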


class RandomCrop(torch.nn.Module):
@@ -1953,3 +1965,165 @@ def forward(self, img):

    def __repr__(self):
        return self.__class__.__name__ + '(p={})'.format(self.p)


Comment (Contributor Author): I changed the default to p=0.5 to keep it consistent with other transforms.

# TODO: move this to references before merging and delete the tests
class RandomMixup(torch.nn.Module):
    """Randomly apply Mixup to the provided batch and targets.
    The class implements the data augmentations as described in the paper
    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.

    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for mixup.
            Default value is 1.0.
        inplace (bool): boolean to make this transform inplace. Default set to False.
    """

    def __init__(self, num_classes: int,
                 p: float = 0.5, alpha: float = 1.0,
                 inplace: bool = False) -> None:
        super().__init__()
        assert num_classes > 0, "Please provide a valid positive value for the num_classes."
        assert alpha > 0, "Alpha param can't be zero."

        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer tensor of size (B, )

        Returns:
            Tuple[Tensor, Tensor]: Randomly transformed batch and target.
        """
        if batch.ndim != 4:
            raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim))
        elif target.ndim != 1:
            raise ValueError("Target ndim should be 1. Got {}".format(target.ndim))
        elif target.dtype != torch.int64:
            raise ValueError("Target dtype should be torch.int64. Got {}".format(target.dtype))

        if not self.inplace:
            batch = batch.clone()
            # target needs no clone: one_hot below already produces a new tensor

        target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=torch.float32)
        if torch.rand(1).item() >= self.p:
            return batch, target

        # It's faster to roll the batch by one instead of shuffling it to create image pairs
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)

        # Implemented as described in the mixup paper, page 3.
        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
        batch_rolled.mul_(1.0 - lambda_param)
        batch.mul_(lambda_param).add_(batch_rolled)

        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)

        return batch, target

    def __repr__(self) -> str:
        s = self.__class__.__name__ + '('
        s += 'num_classes={num_classes}'
        s += ', p={p}'
        s += ', alpha={alpha}'
        s += ', inplace={inplace}'
        s += ')'
        return s.format(**self.__dict__)
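
A small sketch of the lambda sampling used above: a Dirichlet with concentration [alpha, alpha] yields a pair summing to 1 whose first entry follows a Beta(alpha, alpha) distribution, which is the mixing coefficient the mixup paper prescribes. torch.distributions.Beta is the public equivalent of the private helper:

import torch

alpha = 1.0
lam_private = float(torch._sample_dirichlet(torch.tensor([alpha, alpha]))[0])
lam_public = float(torch.distributions.Beta(alpha, alpha).sample())
# Both lie in [0, 1]; with alpha=1.0 the distribution is uniform.
print(lam_private, lam_public)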


Comment (Contributor Author): There is some code duplication, and thus bits that can be shared across the two classes, but this will be fixed in the new API with proper class inheritance.

class RandomCutmix(torch.nn.Module):
    """Randomly apply Cutmix to the provided batch and targets.
    The class implements the data augmentations as described in the paper
    `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
    <https://arxiv.org/abs/1905.04899>`_.

    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for cutmix.
            Default value is 1.0.
        inplace (bool): boolean to make this transform inplace. Default set to False.
    """

    def __init__(self, num_classes: int,
                 p: float = 0.5, alpha: float = 1.0,
                 inplace: bool = False) -> None:
        super().__init__()
        assert num_classes > 0, "Please provide a valid positive value for the num_classes."
        assert alpha > 0, "Alpha param can't be zero."

        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer tensor of size (B, )

        Returns:
            Tuple[Tensor, Tensor]: Randomly transformed batch and target.
        """
        if batch.ndim != 4:
            raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim))
        elif target.ndim != 1:
            raise ValueError("Target ndim should be 1. Got {}".format(target.ndim))
        elif target.dtype != torch.int64:
            raise ValueError("Target dtype should be torch.int64. Got {}".format(target.dtype))

        if not self.inplace:
            batch = batch.clone()
            # target needs no clone: one_hot below already produces a new tensor

        target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=torch.float32)
        if torch.rand(1).item() >= self.p:
            return batch, target

        # It's faster to roll the batch by one instead of shuffling it to create image pairs
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)

        # Implemented as described in the cutmix paper, page 12 (with minor corrections on typos).
        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
        W, H = F.get_image_size(batch)

        r_x = torch.randint(W, (1,))
        r_y = torch.randint(H, (1,))

        r = 0.5 * math.sqrt(1.0 - lambda_param)
        r_w_half = int(r * W)
        r_h_half = int(r * H)

        x1 = int(torch.clamp(r_x - r_w_half, min=0))
        y1 = int(torch.clamp(r_y - r_h_half, min=0))
        x2 = int(torch.clamp(r_x + r_w_half, max=W))
        y2 = int(torch.clamp(r_y + r_h_half, max=H))

        batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
        # Adjust lambda to the actual (clamped) box area, as in the paper.
        lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))

        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)

        return batch, target

    def __repr__(self) -> str:
        s = self.__class__.__name__ + '('
        s += 'num_classes={num_classes}'
        s += ', p={p}'
        s += ', alpha={alpha}'
        s += ', inplace={inplace}'
        s += ')'
        return s.format(**self.__dict__)
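
To make the box arithmetic above concrete, a worked sketch with assumed numbers (a 224x224 image, lambda=0.3, and a fixed box centre instead of the random one):

import math

W, H = 224, 224
lambda_param = 0.3
r = 0.5 * math.sqrt(1.0 - lambda_param)   # half the box side, relative to the image
r_w_half, r_h_half = int(r * W), int(r * H)
r_x, r_y = 100, 50                        # assumed box centre (randomized in the transform)
x1, y1 = max(r_x - r_w_half, 0), max(r_y - r_h_half, 0)
x2, y2 = min(r_x + r_w_half, W), min(r_y + r_h_half, H)
# The box is clamped to the image, so lambda is recomputed from the actual area:
# here the clamped box covers ~53% of the image, giving an adjusted lambda of ~0.47.
adjusted = 1.0 - (x2 - x1) * (y2 - y1) / (W * H)
print((x1, y1, x2, y2), round(adjusted, 3))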